cognate_core/lib.rs
1//! Cognate Core — HTTP client, traits, and base types for LLM providers.
2//!
3//! This crate provides the foundational abstractions for building
4//! provider-agnostic LLM applications with type-safe interfaces and
5//! zero-cost abstractions.
6//!
7//! # Quick start
8//!
9//! ```rust,no_run
10//! use cognate_core::{Provider, Request, Message};
11//!
12//! async fn run<P: Provider>(provider: &P) -> cognate_core::Result<()> {
13//! let response = provider
14//! .complete(
15//! Request::new()
16//! .with_model("gpt-4o-mini")
17//! .with_message(Message::user("Hello!")),
18//! )
19//! .await?;
20//! println!("{}", response.content());
21//! Ok(())
22//! }
23//! ```
24#![warn(missing_docs)]
25
26use async_trait::async_trait;
27use futures::stream::BoxStream;
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30
31pub mod error;
32pub mod middleware;
33pub mod mock;
34pub mod ratelimit;
35pub mod types;
36
37pub use error::{Error, Result};
38pub use middleware::{Layer, Middleware, ProviderExt};
39pub use mock::MockProvider;
40pub use ratelimit::TokenBucket;
41
42// ─── Provider trait ────────────────────────────────────────────────────────
43
44/// Core trait for all LLM providers.
45///
46/// Implement this trait to add support for a new LLM provider.
47/// The trait is object-safe and supports both full completion and
48/// streaming responses.
49///
50/// # Example
51///
52/// ```rust,no_run
53/// use cognate_core::{Provider, Request, Message};
54///
55/// async fn example<P: Provider>(provider: &P) -> cognate_core::Result<()> {
56/// let request = Request::new()
57/// .with_model("gpt-4o")
58/// .with_messages(vec![
59/// Message::system("You are a helpful assistant"),
60/// Message::user("Hello!"),
61/// ]);
62///
63/// let response = provider.complete(request).await?;
64/// println!("{}", response.content());
65/// Ok(())
66/// }
67/// ```
68#[async_trait]
69pub trait Provider: Send + Sync {
70 /// Send a completion request and wait for the full response.
71 async fn complete(&self, req: Request) -> Result<Response>;
72
73 /// Send a completion request and return a streaming response.
74 ///
75 /// Returns a stream of [`Chunk`]s as they are generated by the provider.
76 ///
77 /// # Example
78 ///
79 /// ```rust,no_run
80 /// use cognate_core::{Provider, Request, Message};
81 /// use futures::StreamExt;
82 ///
83 /// async fn stream_example<P: Provider>(provider: &P) -> cognate_core::Result<()> {
84 /// let mut stream = provider
85 /// .stream(Request::new().with_model("gpt-4o").with_message(Message::user("Hi")))
86 /// .await?;
87 /// while let Some(chunk) = stream.next().await {
88 /// print!("{}", chunk?.content());
89 /// }
90 /// Ok(())
91 /// }
92 /// ```
93 async fn stream(&self, req: Request) -> Result<BoxStream<'static, Result<Chunk>>>;
94}
95
96/// Trait for providers that can generate embedding vectors.
97///
98/// Implement this alongside [`Provider`] if your backend supports embeddings.
99///
100/// # Example
101///
102/// ```rust,no_run
103/// use cognate_core::EmbeddingProvider;
104///
105/// async fn embed<E: EmbeddingProvider>(embedder: &E) -> cognate_core::Result<()> {
106/// let vecs = embedder.embed(vec!["Hello world".to_string()]).await?;
107/// println!("Embedding dimension: {}", vecs[0].len());
108/// Ok(())
109/// }
110/// ```
111#[async_trait]
112pub trait EmbeddingProvider: Send + Sync {
113 /// Generate embedding vectors for the given list of input strings.
114 ///
115 /// Returns one vector per input in the same order as `inputs`.
116 async fn embed(&self, inputs: Vec<String>) -> Result<Vec<Vec<f32>>>;
117}
118
119// ─── Message / Role ────────────────────────────────────────────────────────
120
121/// A single message in a conversation.
122#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
123pub struct Message {
124 /// The role of the message sender.
125 pub role: Role,
126 /// The text content of the message.
127 ///
128 /// For assistant messages that contain only tool calls, this may be empty.
129 pub content: String,
130 /// Optional name used to distinguish multiple participants of the same role.
131 #[serde(skip_serializing_if = "Option::is_none")]
132 pub name: Option<String>,
133 /// Tool calls requested by the assistant, if any.
134 ///
135 /// Present on messages with `role = Assistant` when the model wants to
136 /// invoke one or more tools.
137 #[serde(skip_serializing_if = "Option::is_none")]
138 pub tool_calls: Option<Vec<ToolCall>>,
139 /// The tool-call ID this message is responding to.
140 ///
141 /// Must be set on messages with `role = Tool` so the provider can
142 /// correlate the result with the originating call.
143 #[serde(skip_serializing_if = "Option::is_none")]
144 pub tool_call_id: Option<String>,
145}
146
147/// The role of a message sender.
148#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
149#[serde(rename_all = "lowercase")]
150pub enum Role {
151 /// A high-level instruction that shapes the assistant's behaviour.
152 System,
153 /// A message from the human end of the conversation.
154 User,
155 /// A message generated by the assistant.
156 Assistant,
157 /// A legacy function-call result (OpenAI function-calling v1).
158 #[serde(rename = "function")]
159 Function,
160 /// A tool-call result sent back by the client.
161 Tool,
162}
163
164impl Message {
165 /// Create a system message.
166 pub fn system(content: impl Into<String>) -> Self {
167 Self {
168 role: Role::System,
169 content: content.into(),
170 name: None,
171 tool_calls: None,
172 tool_call_id: None,
173 }
174 }
175
176 /// Create a user message.
177 pub fn user(content: impl Into<String>) -> Self {
178 Self {
179 role: Role::User,
180 content: content.into(),
181 name: None,
182 tool_calls: None,
183 tool_call_id: None,
184 }
185 }
186
187 /// Create an assistant message.
188 pub fn assistant(content: impl Into<String>) -> Self {
189 Self {
190 role: Role::Assistant,
191 content: content.into(),
192 name: None,
193 tool_calls: None,
194 tool_call_id: None,
195 }
196 }
197
198 /// Create a tool-result message.
199 ///
200 /// `tool_call_id` must match the `id` of the [`ToolCall`] being answered.
201 pub fn tool_result(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
202 Self {
203 role: Role::Tool,
204 content: content.into(),
205 name: None,
206 tool_calls: None,
207 tool_call_id: Some(tool_call_id.into()),
208 }
209 }
210}
211
212// ─── Tool calling ──────────────────────────────────────────────────────────
213
214/// A tool invocation requested by the model.
215#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
216pub struct ToolCall {
217 /// Unique identifier for this call, used to correlate the result.
218 pub id: String,
219 /// The type of call — currently always `"function"`.
220 #[serde(rename = "type")]
221 pub call_type: String,
222 /// The function the model wants to call.
223 pub function: ToolCallFunction,
224}
225
226/// The function component of a [`ToolCall`].
227#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
228pub struct ToolCallFunction {
229 /// Name of the function to invoke.
230 pub name: String,
231 /// JSON-encoded arguments string, e.g. `"{\"query\":\"Rust\"}"`.
232 pub arguments: String,
233}
234
235// ─── Request ───────────────────────────────────────────────────────────────
236
237/// A completion request sent to a provider.
238#[derive(Debug, Clone, Serialize, Deserialize, Default)]
239pub struct Request {
240 /// The model identifier, e.g. `"gpt-4o"` or `"claude-3-5-sonnet-20241022"`.
241 pub model: String,
242 /// The conversation history, including any system prompt.
243 pub messages: Vec<Message>,
244 /// Sampling temperature in `[0.0, 2.0]`.
245 #[serde(skip_serializing_if = "Option::is_none")]
246 pub temperature: Option<f32>,
247 /// Maximum tokens to generate.
248 #[serde(skip_serializing_if = "Option::is_none")]
249 pub max_tokens: Option<u32>,
250 /// Nucleus sampling parameter.
251 #[serde(skip_serializing_if = "Option::is_none")]
252 pub top_p: Option<f32>,
253 /// Frequency penalty in `[-2.0, 2.0]` (OpenAI only).
254 #[serde(skip_serializing_if = "Option::is_none")]
255 pub frequency_penalty: Option<f32>,
256 /// Presence penalty in `[-2.0, 2.0]` (OpenAI only).
257 #[serde(skip_serializing_if = "Option::is_none")]
258 pub presence_penalty: Option<f32>,
259 /// Stop sequences.
260 #[serde(skip_serializing_if = "Option::is_none")]
261 pub stop: Option<Vec<String>>,
262 /// Whether to stream the response. Providers handle this internally.
263 #[serde(skip_serializing_if = "Option::is_none")]
264 pub stream: Option<bool>,
265 /// Structured output format (`json_object` etc.).
266 #[serde(skip_serializing_if = "Option::is_none")]
267 pub response_format: Option<ResponseFormat>,
268 /// Provider-specific extra parameters (e.g. `tools`, `tool_choice`).
269 #[serde(skip_serializing_if = "HashMap::is_empty", default)]
270 pub extra: HashMap<String, serde_json::Value>,
271}
272
273impl Request {
274 /// Create a new empty request.
275 pub fn new() -> Self {
276 Self::default()
277 }
278
279 /// Set the model identifier.
280 pub fn with_model(mut self, model: impl Into<String>) -> Self {
281 self.model = model.into();
282 self
283 }
284
285 /// Set the full message list.
286 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
287 self.messages = messages;
288 self
289 }
290
291 /// Append a single message.
292 pub fn with_message(mut self, message: Message) -> Self {
293 self.messages.push(message);
294 self
295 }
296
297 /// Set the sampling temperature.
298 pub fn with_temperature(mut self, temperature: f32) -> Self {
299 self.temperature = Some(temperature);
300 self
301 }
302
303 /// Set the maximum number of tokens to generate.
304 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
305 self.max_tokens = Some(max_tokens);
306 self
307 }
308
309 /// Set the `top_p` nucleus sampling parameter.
310 pub fn with_top_p(mut self, top_p: f32) -> Self {
311 self.top_p = Some(top_p);
312 self
313 }
314
315 /// Enable JSON mode (structured output).
316 pub fn with_json_mode(mut self) -> Self {
317 self.response_format = Some(ResponseFormat::json_object());
318 self
319 }
320
321 /// Insert a provider-specific extra parameter.
322 pub fn with_extra(
323 mut self,
324 key: impl Into<String>,
325 value: impl Into<serde_json::Value>,
326 ) -> Self {
327 self.extra.insert(key.into(), value.into());
328 self
329 }
330}
331
332// ─── ResponseFormat ────────────────────────────────────────────────────────
333
334/// Structured output format specifier.
335#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct ResponseFormat {
337 /// The format type — e.g. `"json_object"`.
338 #[serde(rename = "type")]
339 pub format_type: String,
340}
341
342impl ResponseFormat {
343 /// Request JSON object output.
344 pub fn json_object() -> Self {
345 Self {
346 format_type: "json_object".to_string(),
347 }
348 }
349}
350
351// ─── Response / Choice / Usage / Chunk ────────────────────────────────────
352
353/// A completed response from a provider.
354#[derive(Debug, Clone, Serialize, Deserialize)]
355pub struct Response {
356 /// Provider-assigned response identifier.
357 pub id: String,
358 /// Model that generated the response.
359 pub model: String,
360 /// One or more completion choices (usually one).
361 pub choices: Vec<Choice>,
362 /// Token usage statistics, if the provider returned them.
363 pub usage: Option<Usage>,
364 /// Unix timestamp of when the response was created.
365 pub created: Option<u64>,
366}
367
368impl Response {
369 /// Return the text content of the first choice.
370 ///
371 /// Returns an empty string if there are no choices.
372 pub fn content(&self) -> &str {
373 self.choices
374 .first()
375 .map(|c| c.message.content.as_str())
376 .unwrap_or("")
377 }
378
379 /// Return token usage statistics, if available.
380 pub fn usage(&self) -> Option<&Usage> {
381 self.usage.as_ref()
382 }
383
384 /// Return the tool calls from the first choice, if any.
385 pub fn tool_calls(&self) -> Option<&Vec<ToolCall>> {
386 self.choices
387 .first()
388 .and_then(|c| c.message.tool_calls.as_ref())
389 }
390}
391
392/// A single completion choice within a [`Response`].
393#[derive(Debug, Clone, Serialize, Deserialize)]
394pub struct Choice {
395 /// Zero-based index of this choice.
396 pub index: u32,
397 /// The message generated for this choice.
398 pub message: Message,
399 /// Reason the model stopped generating, e.g. `"stop"` or `"tool_calls"`.
400 #[serde(skip_serializing_if = "Option::is_none")]
401 pub finish_reason: Option<String>,
402}
403
404/// Token usage statistics for a request.
405#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct Usage {
407 /// Number of tokens in the prompt.
408 pub prompt_tokens: u32,
409 /// Number of tokens generated.
410 pub completion_tokens: u32,
411 /// Total tokens consumed (`prompt_tokens + completion_tokens`).
412 pub total_tokens: u32,
413}
414
415impl Usage {
416 /// Calculate the USD cost of this request.
417 ///
418 /// `prompt_price` and `completion_price` are expressed as USD per 1 000 tokens.
419 pub fn calculate_cost(&self, prompt_price: f64, completion_price: f64) -> f64 {
420 let prompt_cost = (self.prompt_tokens as f64 / 1000.0) * prompt_price;
421 let completion_cost = (self.completion_tokens as f64 / 1000.0) * completion_price;
422 prompt_cost + completion_cost
423 }
424}
425
426/// A single chunk in a streaming response.
427#[derive(Debug, Clone, Serialize, Deserialize)]
428pub struct Chunk {
429 /// Provider-assigned response identifier.
430 pub id: String,
431 /// Model that generated this chunk.
432 pub model: String,
433 /// The incremental content delta.
434 pub delta: Delta,
435 /// Set on the final chunk — e.g. `"stop"` or `"tool_calls"`.
436 pub finish_reason: Option<String>,
437}
438
439impl Chunk {
440 /// Return the incremental text content of this chunk.
441 pub fn content(&self) -> &str {
442 &self.delta.content
443 }
444
445 /// Return `true` if this is the terminal chunk of the stream.
446 pub fn is_finished(&self) -> bool {
447 self.finish_reason.is_some()
448 }
449}
450
451/// The incremental content delta inside a [`Chunk`].
452#[derive(Debug, Clone, Serialize, Deserialize, Default)]
453pub struct Delta {
454 /// Role of the speaker, present only in the first chunk of a response.
455 pub role: Option<Role>,
456 /// Incremental text content generated since the previous chunk.
457 pub content: String,
458}
459
460// ─── ProviderConfig ────────────────────────────────────────────────────────
461
/// Configuration shared by all provider clients.
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// config with `{:?}` (for example in logs or panic messages) never prints
/// the API key.
#[derive(Clone)]
pub struct ProviderConfig {
    /// API key used to authenticate requests.
    pub api_key: String,
    /// Override for the provider's default base URL.
    ///
    /// [`ProviderConfig::new`] leaves this empty.
    // NOTE(review): presumably an empty string means "use the provider's
    // default URL" — confirm against the provider clients.
    pub base_url: String,
    /// Request timeout in seconds.
    pub timeout_seconds: u64,
    /// Maximum number of automatic retries on transient errors.
    pub max_retries: u32,
}

impl std::fmt::Debug for ProviderConfig {
    /// Format the config with the API key masked.
    ///
    /// The key is shown as `<redacted>`, or `<empty>` when no key is set, so
    /// a missing key can still be diagnosed without exposing a real one.
    // Paths are fully qualified because this file re-exports its own `Result`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let masked_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("ProviderConfig")
            .field("api_key", &masked_key)
            .field("base_url", &self.base_url)
            .field("timeout_seconds", &self.timeout_seconds)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}

impl ProviderConfig {
    /// Create a minimal config with only an API key.
    ///
    /// Defaults: empty `base_url`, a 60-second timeout and 3 retries.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: String::new(),
            timeout_seconds: 60,
            max_retries: 3,
        }
    }

    /// Override the default base URL (useful for proxies or local servers).
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Set the request timeout in seconds.
    pub fn with_timeout(mut self, seconds: u64) -> Self {
        self.timeout_seconds = seconds;
        self
    }

    /// Set the maximum number of automatic retries.
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }
}
503}