Skip to main content

brainwires_core/
provider.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures::stream::BoxStream;
4use serde::{Deserialize, Serialize};
5
6use crate::message::{ChatResponse, Message, StreamChunk};
7use crate::tool::Tool;
8
9/// Base provider trait for AI providers
10#[async_trait]
11pub trait Provider: Send + Sync {
12    /// Get the provider name
13    fn name(&self) -> &str;
14
15    /// Get the model's maximum output tokens (for setting appropriate limits)
16    /// Returns None if the model doesn't have a specific limit
17    fn max_output_tokens(&self) -> Option<u32> {
18        None // Default implementation - providers can override
19    }
20
21    /// Chat completion (non-streaming)
22    async fn chat(
23        &self,
24        messages: &[Message],
25        tools: Option<&[Tool]>,
26        options: &ChatOptions,
27    ) -> Result<ChatResponse>;
28
29    /// Chat completion (streaming)
30    fn stream_chat<'a>(
31        &'a self,
32        messages: &'a [Message],
33        tools: Option<&'a [Tool]>,
34        options: &'a ChatOptions,
35    ) -> BoxStream<'a, Result<StreamChunk>>;
36}
37
38/// Chat completion options
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ChatOptions {
41    /// Temperature (0.0 - 1.0)
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub temperature: Option<f32>,
44    /// Maximum tokens to generate
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub max_tokens: Option<u32>,
47    /// Top-p sampling
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub top_p: Option<f32>,
50    /// Stop sequences
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub stop: Option<Vec<String>>,
53    /// System prompt
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub system: Option<String>,
56}
57
58impl Default for ChatOptions {
59    fn default() -> Self {
60        Self {
61            temperature: Some(0.7),
62            max_tokens: Some(4096),
63            top_p: None,
64            stop: None,
65            system: None,
66        }
67    }
68}
69
70impl ChatOptions {
71    /// Create new chat options with defaults
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    /// Set temperature
77    pub fn temperature(mut self, temperature: f32) -> Self {
78        self.temperature = Some(temperature);
79        self
80    }
81
82    /// Set max tokens
83    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
84        self.max_tokens = Some(max_tokens);
85        self
86    }
87
88    /// Set system prompt
89    pub fn system<S: Into<String>>(mut self, system: S) -> Self {
90        self.system = Some(system.into());
91        self
92    }
93
94    /// Set top-p sampling
95    pub fn top_p(mut self, top_p: f32) -> Self {
96        self.top_p = Some(top_p);
97        self
98    }
99
100    /// Deterministic classification/routing (temp=0, few tokens)
101    pub fn deterministic(max_tokens: u32) -> Self {
102        Self {
103            temperature: Some(0.0),
104            max_tokens: Some(max_tokens),
105            ..Default::default()
106        }
107    }
108
109    /// Low-temperature factual generation
110    pub fn factual(max_tokens: u32) -> Self {
111        Self {
112            temperature: Some(0.1),
113            max_tokens: Some(max_tokens),
114            top_p: Some(0.9),
115            ..Default::default()
116        }
117    }
118
119    /// Creative generation with moderate temperature
120    pub fn creative(max_tokens: u32) -> Self {
121        Self {
122            temperature: Some(0.3),
123            max_tokens: Some(max_tokens),
124            ..Default::default()
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_options_default() {
        let opts = ChatOptions::default();
        assert_eq!(opts.temperature, Some(0.7));
        assert_eq!(opts.max_tokens, Some(4096));
        // Only temperature and max_tokens are pre-populated; the rest start unset.
        assert_eq!(opts.top_p, None);
        assert_eq!(opts.stop, None);
        assert_eq!(opts.system, None);
    }

    #[test]
    fn test_chat_options_new_uses_defaults() {
        let opts = ChatOptions::new();
        assert_eq!(opts.temperature, Some(0.7));
        assert_eq!(opts.max_tokens, Some(4096));
        assert_eq!(opts.top_p, None);
    }

    #[test]
    fn test_chat_options_builder() {
        let opts = ChatOptions::new()
            .temperature(0.5)
            .max_tokens(2048)
            .system("Test");
        assert_eq!(opts.temperature, Some(0.5));
        assert_eq!(opts.max_tokens, Some(2048));
        assert_eq!(opts.system, Some("Test".to_string()));
    }

    #[test]
    fn test_chat_options_top_p() {
        let opts = ChatOptions::new().top_p(0.95);
        assert_eq!(opts.top_p, Some(0.95));
        // Setting top_p must not disturb the other defaults.
        assert_eq!(opts.temperature, Some(0.7));
        assert_eq!(opts.max_tokens, Some(4096));
    }

    #[test]
    fn test_chat_options_deterministic() {
        let opts = ChatOptions::deterministic(50);
        assert_eq!(opts.temperature, Some(0.0));
        assert_eq!(opts.max_tokens, Some(50));
        assert_eq!(opts.top_p, None);
        assert_eq!(opts.system, None);
    }

    #[test]
    fn test_chat_options_factual() {
        let opts = ChatOptions::factual(200);
        assert_eq!(opts.temperature, Some(0.1));
        assert_eq!(opts.max_tokens, Some(200));
        assert_eq!(opts.top_p, Some(0.9));
        assert_eq!(opts.stop, None);
    }

    #[test]
    fn test_chat_options_creative() {
        let opts = ChatOptions::creative(400);
        assert_eq!(opts.temperature, Some(0.3));
        assert_eq!(opts.max_tokens, Some(400));
        assert_eq!(opts.top_p, None);
    }
}