Skip to main content

cognate_core/
lib.rs

//! Cognate Core — HTTP client, traits, and base types for LLM providers.
//!
//! This crate provides the foundational abstractions for building
//! provider-agnostic LLM applications with type-safe interfaces and
//! zero-cost abstractions.
//!
//! # Quick start
//!
//! ```rust,no_run
//! use cognate_core::{Provider, Request, Message};
//!
//! async fn run<P: Provider>(provider: &P) -> cognate_core::Result<()> {
//!     let response = provider
//!         .complete(
//!             Request::new()
//!                 .with_model("gpt-4o-mini")
//!                 .with_message(Message::user("Hello!")),
//!         )
//!         .await?;
//!     println!("{}", response.content());
//!     Ok(())
//! }
//! ```
24#![warn(missing_docs)]
25
26use async_trait::async_trait;
27use futures::stream::BoxStream;
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30
31pub mod error;
32pub mod middleware;
33pub mod mock;
34pub mod ratelimit;
35pub mod types;
36
37pub use error::{Error, Result};
38pub use middleware::{Layer, Middleware, ProviderExt};
39pub use mock::MockProvider;
40pub use ratelimit::TokenBucket;
41
42// ─── Provider trait ────────────────────────────────────────────────────────
43
44/// Core trait for all LLM providers.
45///
46/// Implement this trait to add support for a new LLM provider.
47/// The trait is object-safe and supports both full completion and
48/// streaming responses.
49///
50/// # Example
51///
52/// ```rust,no_run
53/// use cognate_core::{Provider, Request, Message};
54///
55/// async fn example<P: Provider>(provider: &P) -> cognate_core::Result<()> {
56///     let request = Request::new()
57///         .with_model("gpt-4o")
58///         .with_messages(vec![
59///             Message::system("You are a helpful assistant"),
60///             Message::user("Hello!"),
61///         ]);
62///
63///     let response = provider.complete(request).await?;
64///     println!("{}", response.content());
65///     Ok(())
66/// }
67/// ```
68#[async_trait]
69pub trait Provider: Send + Sync {
70    /// Send a completion request and wait for the full response.
71    async fn complete(&self, req: Request) -> Result<Response>;
72
73    /// Send a completion request and return a streaming response.
74    ///
75    /// Returns a stream of [`Chunk`]s as they are generated by the provider.
76    ///
77    /// # Example
78    ///
79    /// ```rust,no_run
80    /// use cognate_core::{Provider, Request, Message};
81    /// use futures::StreamExt;
82    ///
83    /// async fn stream_example<P: Provider>(provider: &P) -> cognate_core::Result<()> {
84    ///     let mut stream = provider
85    ///         .stream(Request::new().with_model("gpt-4o").with_message(Message::user("Hi")))
86    ///         .await?;
87    ///     while let Some(chunk) = stream.next().await {
88    ///         print!("{}", chunk?.content());
89    ///     }
90    ///     Ok(())
91    /// }
92    /// ```
93    async fn stream(&self, req: Request) -> Result<BoxStream<'static, Result<Chunk>>>;
94}
95
96/// Trait for providers that can generate embedding vectors.
97///
98/// Implement this alongside [`Provider`] if your backend supports embeddings.
99///
100/// # Example
101///
102/// ```rust,no_run
103/// use cognate_core::EmbeddingProvider;
104///
105/// async fn embed<E: EmbeddingProvider>(embedder: &E) -> cognate_core::Result<()> {
106///     let vecs = embedder.embed(vec!["Hello world".to_string()]).await?;
107///     println!("Embedding dimension: {}", vecs[0].len());
108///     Ok(())
109/// }
110/// ```
111#[async_trait]
112pub trait EmbeddingProvider: Send + Sync {
113    /// Generate embedding vectors for the given list of input strings.
114    ///
115    /// Returns one vector per input in the same order as `inputs`.
116    async fn embed(&self, inputs: Vec<String>) -> Result<Vec<Vec<f32>>>;
117}
118
119// ─── Message / Role ────────────────────────────────────────────────────────
120
121/// A single message in a conversation.
122#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
123pub struct Message {
124    /// The role of the message sender.
125    pub role: Role,
126    /// The text content of the message.
127    ///
128    /// For assistant messages that contain only tool calls, this may be empty.
129    pub content: String,
130    /// Optional name used to distinguish multiple participants of the same role.
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub name: Option<String>,
133    /// Tool calls requested by the assistant, if any.
134    ///
135    /// Present on messages with `role = Assistant` when the model wants to
136    /// invoke one or more tools.
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub tool_calls: Option<Vec<ToolCall>>,
139    /// The tool-call ID this message is responding to.
140    ///
141    /// Must be set on messages with `role = Tool` so the provider can
142    /// correlate the result with the originating call.
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub tool_call_id: Option<String>,
145}
146
147/// The role of a message sender.
148#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
149#[serde(rename_all = "lowercase")]
150pub enum Role {
151    /// A high-level instruction that shapes the assistant's behaviour.
152    System,
153    /// A message from the human end of the conversation.
154    User,
155    /// A message generated by the assistant.
156    Assistant,
157    /// A legacy function-call result (OpenAI function-calling v1).
158    #[serde(rename = "function")]
159    Function,
160    /// A tool-call result sent back by the client.
161    Tool,
162}
163
164impl Message {
165    /// Create a system message.
166    pub fn system(content: impl Into<String>) -> Self {
167        Self {
168            role: Role::System,
169            content: content.into(),
170            name: None,
171            tool_calls: None,
172            tool_call_id: None,
173        }
174    }
175
176    /// Create a user message.
177    pub fn user(content: impl Into<String>) -> Self {
178        Self {
179            role: Role::User,
180            content: content.into(),
181            name: None,
182            tool_calls: None,
183            tool_call_id: None,
184        }
185    }
186
187    /// Create an assistant message.
188    pub fn assistant(content: impl Into<String>) -> Self {
189        Self {
190            role: Role::Assistant,
191            content: content.into(),
192            name: None,
193            tool_calls: None,
194            tool_call_id: None,
195        }
196    }
197
198    /// Create a tool-result message.
199    ///
200    /// `tool_call_id` must match the `id` of the [`ToolCall`] being answered.
201    pub fn tool_result(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
202        Self {
203            role: Role::Tool,
204            content: content.into(),
205            name: None,
206            tool_calls: None,
207            tool_call_id: Some(tool_call_id.into()),
208        }
209    }
210}
211
212// ─── Tool calling ──────────────────────────────────────────────────────────
213
214/// A tool invocation requested by the model.
215#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
216pub struct ToolCall {
217    /// Unique identifier for this call, used to correlate the result.
218    pub id: String,
219    /// The type of call — currently always `"function"`.
220    #[serde(rename = "type")]
221    pub call_type: String,
222    /// The function the model wants to call.
223    pub function: ToolCallFunction,
224}
225
226/// The function component of a [`ToolCall`].
227#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
228pub struct ToolCallFunction {
229    /// Name of the function to invoke.
230    pub name: String,
231    /// JSON-encoded arguments string, e.g. `"{\"query\":\"Rust\"}"`.
232    pub arguments: String,
233}
234
235// ─── Request ───────────────────────────────────────────────────────────────
236
237/// A completion request sent to a provider.
238#[derive(Debug, Clone, Serialize, Deserialize, Default)]
239pub struct Request {
240    /// The model identifier, e.g. `"gpt-4o"` or `"claude-3-5-sonnet-20241022"`.
241    pub model: String,
242    /// The conversation history, including any system prompt.
243    pub messages: Vec<Message>,
244    /// Sampling temperature in `[0.0, 2.0]`.
245    #[serde(skip_serializing_if = "Option::is_none")]
246    pub temperature: Option<f32>,
247    /// Maximum tokens to generate.
248    #[serde(skip_serializing_if = "Option::is_none")]
249    pub max_tokens: Option<u32>,
250    /// Nucleus sampling parameter.
251    #[serde(skip_serializing_if = "Option::is_none")]
252    pub top_p: Option<f32>,
253    /// Frequency penalty in `[-2.0, 2.0]` (OpenAI only).
254    #[serde(skip_serializing_if = "Option::is_none")]
255    pub frequency_penalty: Option<f32>,
256    /// Presence penalty in `[-2.0, 2.0]` (OpenAI only).
257    #[serde(skip_serializing_if = "Option::is_none")]
258    pub presence_penalty: Option<f32>,
259    /// Stop sequences.
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub stop: Option<Vec<String>>,
262    /// Whether to stream the response. Providers handle this internally.
263    #[serde(skip_serializing_if = "Option::is_none")]
264    pub stream: Option<bool>,
265    /// Structured output format (`json_object` etc.).
266    #[serde(skip_serializing_if = "Option::is_none")]
267    pub response_format: Option<ResponseFormat>,
268    /// Provider-specific extra parameters (e.g. `tools`, `tool_choice`).
269    #[serde(skip_serializing_if = "HashMap::is_empty", default)]
270    pub extra: HashMap<String, serde_json::Value>,
271}
272
273impl Request {
274    /// Create a new empty request.
275    pub fn new() -> Self {
276        Self::default()
277    }
278
279    /// Set the model identifier.
280    pub fn with_model(mut self, model: impl Into<String>) -> Self {
281        self.model = model.into();
282        self
283    }
284
285    /// Set the full message list.
286    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
287        self.messages = messages;
288        self
289    }
290
291    /// Append a single message.
292    pub fn with_message(mut self, message: Message) -> Self {
293        self.messages.push(message);
294        self
295    }
296
297    /// Set the sampling temperature.
298    pub fn with_temperature(mut self, temperature: f32) -> Self {
299        self.temperature = Some(temperature);
300        self
301    }
302
303    /// Set the maximum number of tokens to generate.
304    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
305        self.max_tokens = Some(max_tokens);
306        self
307    }
308
309    /// Set the `top_p` nucleus sampling parameter.
310    pub fn with_top_p(mut self, top_p: f32) -> Self {
311        self.top_p = Some(top_p);
312        self
313    }
314
315    /// Enable JSON mode (structured output).
316    pub fn with_json_mode(mut self) -> Self {
317        self.response_format = Some(ResponseFormat::json_object());
318        self
319    }
320
321    /// Insert a provider-specific extra parameter.
322    pub fn with_extra(
323        mut self,
324        key: impl Into<String>,
325        value: impl Into<serde_json::Value>,
326    ) -> Self {
327        self.extra.insert(key.into(), value.into());
328        self
329    }
330}
331
332// ─── ResponseFormat ────────────────────────────────────────────────────────
333
334/// Structured output format specifier.
335#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct ResponseFormat {
337    /// The format type — e.g. `"json_object"`.
338    #[serde(rename = "type")]
339    pub format_type: String,
340}
341
342impl ResponseFormat {
343    /// Request JSON object output.
344    pub fn json_object() -> Self {
345        Self {
346            format_type: "json_object".to_string(),
347        }
348    }
349}
350
351// ─── Response / Choice / Usage / Chunk ────────────────────────────────────
352
353/// A completed response from a provider.
354#[derive(Debug, Clone, Serialize, Deserialize)]
355pub struct Response {
356    /// Provider-assigned response identifier.
357    pub id: String,
358    /// Model that generated the response.
359    pub model: String,
360    /// One or more completion choices (usually one).
361    pub choices: Vec<Choice>,
362    /// Token usage statistics, if the provider returned them.
363    pub usage: Option<Usage>,
364    /// Unix timestamp of when the response was created.
365    pub created: Option<u64>,
366}
367
368impl Response {
369    /// Return the text content of the first choice.
370    ///
371    /// Returns an empty string if there are no choices.
372    pub fn content(&self) -> &str {
373        self.choices
374            .first()
375            .map(|c| c.message.content.as_str())
376            .unwrap_or("")
377    }
378
379    /// Return token usage statistics, if available.
380    pub fn usage(&self) -> Option<&Usage> {
381        self.usage.as_ref()
382    }
383
384    /// Return the tool calls from the first choice, if any.
385    pub fn tool_calls(&self) -> Option<&Vec<ToolCall>> {
386        self.choices
387            .first()
388            .and_then(|c| c.message.tool_calls.as_ref())
389    }
390}
391
392/// A single completion choice within a [`Response`].
393#[derive(Debug, Clone, Serialize, Deserialize)]
394pub struct Choice {
395    /// Zero-based index of this choice.
396    pub index: u32,
397    /// The message generated for this choice.
398    pub message: Message,
399    /// Reason the model stopped generating, e.g. `"stop"` or `"tool_calls"`.
400    #[serde(skip_serializing_if = "Option::is_none")]
401    pub finish_reason: Option<String>,
402}
403
404/// Token usage statistics for a request.
405#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct Usage {
407    /// Number of tokens in the prompt.
408    pub prompt_tokens: u32,
409    /// Number of tokens generated.
410    pub completion_tokens: u32,
411    /// Total tokens consumed (`prompt_tokens + completion_tokens`).
412    pub total_tokens: u32,
413}
414
415impl Usage {
416    /// Calculate the USD cost of this request.
417    ///
418    /// `prompt_price` and `completion_price` are expressed as USD per 1 000 tokens.
419    pub fn calculate_cost(&self, prompt_price: f64, completion_price: f64) -> f64 {
420        let prompt_cost = (self.prompt_tokens as f64 / 1000.0) * prompt_price;
421        let completion_cost = (self.completion_tokens as f64 / 1000.0) * completion_price;
422        prompt_cost + completion_cost
423    }
424}
425
426/// A single chunk in a streaming response.
427#[derive(Debug, Clone, Serialize, Deserialize)]
428pub struct Chunk {
429    /// Provider-assigned response identifier.
430    pub id: String,
431    /// Model that generated this chunk.
432    pub model: String,
433    /// The incremental content delta.
434    pub delta: Delta,
435    /// Set on the final chunk — e.g. `"stop"` or `"tool_calls"`.
436    pub finish_reason: Option<String>,
437}
438
439impl Chunk {
440    /// Return the incremental text content of this chunk.
441    pub fn content(&self) -> &str {
442        &self.delta.content
443    }
444
445    /// Return `true` if this is the terminal chunk of the stream.
446    pub fn is_finished(&self) -> bool {
447        self.finish_reason.is_some()
448    }
449}
450
451/// The incremental content delta inside a [`Chunk`].
452#[derive(Debug, Clone, Serialize, Deserialize, Default)]
453pub struct Delta {
454    /// Role of the speaker, present only in the first chunk of a response.
455    pub role: Option<Role>,
456    /// Incremental text content generated since the previous chunk.
457    pub content: String,
458}
459
// ─── ProviderConfig ────────────────────────────────────────────────────────

/// Configuration shared by all provider clients.
///
/// `Debug` is implemented by hand so the API key is never printed. The
/// formatted output shows `api_key: "<redacted>"` instead of the secret, which
/// makes it safe to log a config with `{:?}`.
#[derive(Clone)]
pub struct ProviderConfig {
    /// API key used to authenticate requests.
    ///
    /// Never shown by the `Debug` implementation.
    pub api_key: String,
    /// Override for the provider's default base URL.
    ///
    /// [`ProviderConfig::new`] leaves this empty. Presumably each provider
    /// client treats an empty value as "use the built-in default URL"; confirm
    /// this in the individual client.
    pub base_url: String,
    /// Request timeout in seconds.
    pub timeout_seconds: u64,
    /// Maximum number of automatic retries on transient errors.
    pub max_retries: u32,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redact the credential so configs can be logged without leaking it.
        f.debug_struct("ProviderConfig")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("timeout_seconds", &self.timeout_seconds)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}

impl ProviderConfig {
    /// Create a config with the given API key and default settings.
    ///
    /// Defaults:
    /// - empty `base_url`;
    /// - 60-second timeout;
    /// - 3 retries.
    #[must_use]
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: String::new(),
            timeout_seconds: 60,
            max_retries: 3,
        }
    }

    /// Override the default base URL (useful for proxies or local servers).
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Set the request timeout in seconds.
    #[must_use]
    pub fn with_timeout(mut self, seconds: u64) -> Self {
        self.timeout_seconds = seconds;
        self
    }

    /// Set the maximum number of automatic retries.
    #[must_use]
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }
}