// brainwires_core/provider.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures::stream::BoxStream;
4use serde::{Deserialize, Serialize};
5
6use crate::message::{ChatResponse, Message, StreamChunk};
7use crate::tool::Tool;
8
9/// Base provider trait for AI providers
10#[async_trait]
11pub trait Provider: Send + Sync {
12    /// Get the provider name
13    fn name(&self) -> &str;
14
15    /// Get the model's maximum output tokens (for setting appropriate limits)
16    /// Returns None if the model doesn't have a specific limit
17    fn max_output_tokens(&self) -> Option<u32> {
18        None // Default implementation - providers can override
19    }
20
21    /// Chat completion (non-streaming)
22    async fn chat(
23        &self,
24        messages: &[Message],
25        tools: Option<&[Tool]>,
26        options: &ChatOptions,
27    ) -> Result<ChatResponse>;
28
29    /// Chat completion (streaming)
30    fn stream_chat<'a>(
31        &'a self,
32        messages: &'a [Message],
33        tools: Option<&'a [Tool]>,
34        options: &'a ChatOptions,
35    ) -> BoxStream<'a, Result<StreamChunk>>;
36}
37
38/// Prompt-cache strategy for providers that support explicit caching
39/// (Anthropic Messages API today; a no-op elsewhere).
40///
41/// Controls which parts of a request receive `cache_control` breakpoints.
42/// Caching reuses cached prompt bytes across turns for a 50–90% input-token
43/// discount on subsequent calls, at the cost of a one-time "creation" charge
44/// on first population.
45#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
46#[serde(rename_all = "snake_case")]
47pub enum CacheStrategy {
48    /// No cache breakpoints. Fresh compute on every call.
49    Off,
50    /// Cache only the system prompt.
51    SystemOnly,
52    /// Cache the system prompt and tool definitions (the default).
53    #[default]
54    SystemAndTools,
55    /// Cache system + tools + the tail of the conversation once the message
56    /// history reaches the given approximate token threshold.
57    SystemAndTailTurn {
58        /// Minimum conversation size (approximate tokens) before the tail
59        /// breakpoint is emitted. Avoids wasting a cache slot on short chats.
60        threshold_tokens: u32,
61    },
62}
63
64/// Chat completion options
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ChatOptions {
67    /// Temperature (0.0 - 1.0)
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub temperature: Option<f32>,
70    /// Maximum tokens to generate
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub max_tokens: Option<u32>,
73    /// Top-p sampling
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub top_p: Option<f32>,
76    /// Stop sequences
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub stop: Option<Vec<String>>,
79    /// System prompt
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub system: Option<String>,
82    /// Per-request model override.
83    ///
84    /// When `Some`, providers MUST use this model name instead of their default.
85    /// This enables per-session model switching without replacing the provider.
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub model: Option<String>,
88    /// Prompt-cache strategy. Ignored by providers without prompt caching.
89    #[serde(default)]
90    pub cache_strategy: CacheStrategy,
91}
92
93impl Default for ChatOptions {
94    fn default() -> Self {
95        Self {
96            temperature: Some(0.7),
97            max_tokens: Some(4096),
98            top_p: None,
99            stop: None,
100            system: None,
101            model: None,
102            cache_strategy: CacheStrategy::default(),
103        }
104    }
105}
106
107impl ChatOptions {
108    /// Create new chat options with defaults
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    /// Set temperature
114    pub fn temperature(mut self, temperature: f32) -> Self {
115        self.temperature = Some(temperature);
116        self
117    }
118
119    /// Set max tokens
120    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
121        self.max_tokens = Some(max_tokens);
122        self
123    }
124
125    /// Set system prompt
126    pub fn system<S: Into<String>>(mut self, system: S) -> Self {
127        self.system = Some(system.into());
128        self
129    }
130
131    /// Set top-p sampling
132    pub fn top_p(mut self, top_p: f32) -> Self {
133        self.top_p = Some(top_p);
134        self
135    }
136
137    /// Override the model for this request.
138    pub fn model<S: Into<String>>(mut self, model: S) -> Self {
139        self.model = Some(model.into());
140        self
141    }
142
143    /// Set the prompt-cache strategy.
144    pub fn cache_strategy(mut self, strategy: CacheStrategy) -> Self {
145        self.cache_strategy = strategy;
146        self
147    }
148
149    /// Deterministic classification/routing (temp=0, few tokens)
150    pub fn deterministic(max_tokens: u32) -> Self {
151        Self {
152            temperature: Some(0.0),
153            max_tokens: Some(max_tokens),
154            ..Default::default()
155        }
156    }
157
158    /// Low-temperature factual generation
159    pub fn factual(max_tokens: u32) -> Self {
160        Self {
161            temperature: Some(0.1),
162            max_tokens: Some(max_tokens),
163            top_p: Some(0.9),
164            ..Default::default()
165        }
166    }
167
168    /// Creative generation with moderate temperature
169    pub fn creative(max_tokens: u32) -> Self {
170        Self {
171            temperature: Some(0.3),
172            max_tokens: Some(max_tokens),
173            ..Default::default()
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// `ChatOptions::default()` yields temperature 0.7 and a 4096-token cap.
    #[test]
    fn test_chat_options_default() {
        let ChatOptions {
            temperature,
            max_tokens,
            ..
        } = ChatOptions::default();

        assert_eq!(temperature, Some(0.7));
        assert_eq!(max_tokens, Some(4096));
    }

    /// Chained setters overwrite the corresponding defaults.
    #[test]
    fn test_chat_options_builder() {
        let built = ChatOptions::new()
            .temperature(0.5)
            .max_tokens(2048)
            .system("Test");

        assert_eq!(built.temperature, Some(0.5));
        assert_eq!(built.max_tokens, Some(2048));
        assert_eq!(built.system.as_deref(), Some("Test"));
    }

    /// The deterministic preset pins temperature to 0 and honors the cap.
    #[test]
    fn test_chat_options_deterministic() {
        let ChatOptions {
            temperature,
            max_tokens,
            ..
        } = ChatOptions::deterministic(50);

        assert_eq!(temperature, Some(0.0));
        assert_eq!(max_tokens, Some(50));
    }

    /// The factual preset uses temperature 0.1 and top-p 0.9.
    #[test]
    fn test_chat_options_factual() {
        let ChatOptions {
            temperature,
            max_tokens,
            top_p,
            ..
        } = ChatOptions::factual(200);

        assert_eq!(temperature, Some(0.1));
        assert_eq!(max_tokens, Some(200));
        assert_eq!(top_p, Some(0.9));
    }

    /// The creative preset uses temperature 0.3 and honors the cap.
    #[test]
    fn test_chat_options_creative() {
        let ChatOptions {
            temperature,
            max_tokens,
            ..
        } = ChatOptions::creative(400);

        assert_eq!(temperature, Some(0.3));
        assert_eq!(max_tokens, Some(400));
    }
}