brainwires_core/
provider.rs1use anyhow::Result;
2use async_trait::async_trait;
3use futures::stream::BoxStream;
4use serde::{Deserialize, Serialize};
5
6use crate::message::{ChatResponse, Message, StreamChunk};
7use crate::tool::Tool;
8
9#[async_trait]
11pub trait Provider: Send + Sync {
12 fn name(&self) -> &str;
14
15 fn max_output_tokens(&self) -> Option<u32> {
18 None }
20
21 async fn chat(
23 &self,
24 messages: &[Message],
25 tools: Option<&[Tool]>,
26 options: &ChatOptions,
27 ) -> Result<ChatResponse>;
28
29 fn stream_chat<'a>(
31 &'a self,
32 messages: &'a [Message],
33 tools: Option<&'a [Tool]>,
34 options: &'a ChatOptions,
35 ) -> BoxStream<'a, Result<StreamChunk>>;
36}
37
38#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
46#[serde(rename_all = "snake_case")]
47pub enum CacheStrategy {
48 Off,
50 SystemOnly,
52 #[default]
54 SystemAndTools,
55 SystemAndTailTurn {
58 threshold_tokens: u32,
61 },
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ChatOptions {
67 #[serde(skip_serializing_if = "Option::is_none")]
69 pub temperature: Option<f32>,
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub max_tokens: Option<u32>,
73 #[serde(skip_serializing_if = "Option::is_none")]
75 pub top_p: Option<f32>,
76 #[serde(skip_serializing_if = "Option::is_none")]
78 pub stop: Option<Vec<String>>,
79 #[serde(skip_serializing_if = "Option::is_none")]
81 pub system: Option<String>,
82 #[serde(skip_serializing_if = "Option::is_none")]
87 pub model: Option<String>,
88 #[serde(default)]
90 pub cache_strategy: CacheStrategy,
91}
92
93impl Default for ChatOptions {
94 fn default() -> Self {
95 Self {
96 temperature: Some(0.7),
97 max_tokens: Some(4096),
98 top_p: None,
99 stop: None,
100 system: None,
101 model: None,
102 cache_strategy: CacheStrategy::default(),
103 }
104 }
105}
106
107impl ChatOptions {
108 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub fn temperature(mut self, temperature: f32) -> Self {
115 self.temperature = Some(temperature);
116 self
117 }
118
119 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
121 self.max_tokens = Some(max_tokens);
122 self
123 }
124
125 pub fn system<S: Into<String>>(mut self, system: S) -> Self {
127 self.system = Some(system.into());
128 self
129 }
130
131 pub fn top_p(mut self, top_p: f32) -> Self {
133 self.top_p = Some(top_p);
134 self
135 }
136
137 pub fn model<S: Into<String>>(mut self, model: S) -> Self {
139 self.model = Some(model.into());
140 self
141 }
142
143 pub fn cache_strategy(mut self, strategy: CacheStrategy) -> Self {
145 self.cache_strategy = strategy;
146 self
147 }
148
149 pub fn deterministic(max_tokens: u32) -> Self {
151 Self {
152 temperature: Some(0.0),
153 max_tokens: Some(max_tokens),
154 ..Default::default()
155 }
156 }
157
158 pub fn factual(max_tokens: u32) -> Self {
160 Self {
161 temperature: Some(0.1),
162 max_tokens: Some(max_tokens),
163 top_p: Some(0.9),
164 ..Default::default()
165 }
166 }
167
168 pub fn creative(max_tokens: u32) -> Self {
170 Self {
171 temperature: Some(0.3),
172 max_tokens: Some(max_tokens),
173 ..Default::default()
174 }
175 }
176}
177
/// Unit tests for the `ChatOptions` defaults, builder and presets.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_options_default() {
        let options = ChatOptions::default();
        assert_eq!(
            (options.temperature, options.max_tokens),
            (Some(0.7), Some(4096))
        );
    }

    #[test]
    fn test_chat_options_builder() {
        let built = ChatOptions::new().temperature(0.5).max_tokens(2048).system("Test");
        assert_eq!(
            (built.temperature, built.max_tokens),
            (Some(0.5), Some(2048))
        );
        assert_eq!(built.system.as_deref(), Some("Test"));
    }

    #[test]
    fn test_chat_options_deterministic() {
        let preset = ChatOptions::deterministic(50);
        assert_eq!((preset.temperature, preset.max_tokens), (Some(0.0), Some(50)));
    }

    #[test]
    fn test_chat_options_factual() {
        let preset = ChatOptions::factual(200);
        assert_eq!(
            (preset.temperature, preset.max_tokens, preset.top_p),
            (Some(0.1), Some(200), Some(0.9))
        );
    }

    #[test]
    fn test_chat_options_creative() {
        let preset = ChatOptions::creative(400);
        assert_eq!((preset.temperature, preset.max_tokens), (Some(0.3), Some(400)));
    }
}