brainwires_core/
provider.rs1use anyhow::Result;
2use async_trait::async_trait;
3use futures::stream::BoxStream;
4use serde::{Deserialize, Serialize};
5
6use crate::message::{ChatResponse, Message, StreamChunk};
7use crate::tool::Tool;
8
9#[async_trait]
11pub trait Provider: Send + Sync {
12 fn name(&self) -> &str;
14
15 fn max_output_tokens(&self) -> Option<u32> {
18 None }
20
21 async fn chat(
23 &self,
24 messages: &[Message],
25 tools: Option<&[Tool]>,
26 options: &ChatOptions,
27 ) -> Result<ChatResponse>;
28
29 fn stream_chat<'a>(
31 &'a self,
32 messages: &'a [Message],
33 tools: Option<&'a [Tool]>,
34 options: &'a ChatOptions,
35 ) -> BoxStream<'a, Result<StreamChunk>>;
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ChatOptions {
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub temperature: Option<f32>,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub max_tokens: Option<u32>,
47 #[serde(skip_serializing_if = "Option::is_none")]
49 pub top_p: Option<f32>,
50 #[serde(skip_serializing_if = "Option::is_none")]
52 pub stop: Option<Vec<String>>,
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub system: Option<String>,
56}
57
58impl Default for ChatOptions {
59 fn default() -> Self {
60 Self {
61 temperature: Some(0.7),
62 max_tokens: Some(4096),
63 top_p: None,
64 stop: None,
65 system: None,
66 }
67 }
68}
69
70impl ChatOptions {
71 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn temperature(mut self, temperature: f32) -> Self {
78 self.temperature = Some(temperature);
79 self
80 }
81
82 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
84 self.max_tokens = Some(max_tokens);
85 self
86 }
87
88 pub fn system<S: Into<String>>(mut self, system: S) -> Self {
90 self.system = Some(system.into());
91 self
92 }
93
94 pub fn top_p(mut self, top_p: f32) -> Self {
96 self.top_p = Some(top_p);
97 self
98 }
99
100 pub fn deterministic(max_tokens: u32) -> Self {
102 Self {
103 temperature: Some(0.0),
104 max_tokens: Some(max_tokens),
105 ..Default::default()
106 }
107 }
108
109 pub fn factual(max_tokens: u32) -> Self {
111 Self {
112 temperature: Some(0.1),
113 max_tokens: Some(max_tokens),
114 top_p: Some(0.9),
115 ..Default::default()
116 }
117 }
118
119 pub fn creative(max_tokens: u32) -> Self {
121 Self {
122 temperature: Some(0.3),
123 max_tokens: Some(max_tokens),
124 ..Default::default()
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` populates temperature and max_tokens and leaves the rest unset.
    #[test]
    fn test_chat_options_default() {
        let opts = ChatOptions::default();
        assert_eq!(opts.temperature, Some(0.7));
        assert_eq!(opts.max_tokens, Some(4096));
        assert_eq!(opts.top_p, None);
        assert_eq!(opts.stop, None);
        assert_eq!(opts.system, None);
    }

    /// `new()` is a thin wrapper over `Default`.
    #[test]
    fn test_chat_options_new_matches_default() {
        let fresh = ChatOptions::new();
        let baseline = ChatOptions::default();
        assert_eq!(fresh.temperature, baseline.temperature);
        assert_eq!(fresh.max_tokens, baseline.max_tokens);
        assert_eq!(fresh.top_p, baseline.top_p);
        assert_eq!(fresh.stop, baseline.stop);
        assert_eq!(fresh.system, baseline.system);
    }

    /// Chained setters each overwrite exactly their own field.
    #[test]
    fn test_chat_options_builder() {
        let opts = ChatOptions::new()
            .temperature(0.5)
            .max_tokens(2048)
            .system("Test");
        assert_eq!(opts.temperature, Some(0.5));
        assert_eq!(opts.max_tokens, Some(2048));
        assert_eq!(opts.system, Some("Test".to_string()));
        // Fields that were never set keep their default.
        assert_eq!(opts.top_p, None);
        assert_eq!(opts.stop, None);
    }

    /// The `top_p` setter stores its value and leaves other fields alone.
    #[test]
    fn test_chat_options_top_p_builder() {
        let opts = ChatOptions::new().top_p(0.8);
        assert_eq!(opts.top_p, Some(0.8));
        assert_eq!(opts.temperature, Some(0.7));
        assert_eq!(opts.max_tokens, Some(4096));
    }

    /// `deterministic` pins temperature to zero and takes the caller's limit.
    #[test]
    fn test_chat_options_deterministic() {
        let opts = ChatOptions::deterministic(50);
        assert_eq!(opts.temperature, Some(0.0));
        assert_eq!(opts.max_tokens, Some(50));
        assert_eq!(opts.top_p, None);
        assert_eq!(opts.stop, None);
        assert_eq!(opts.system, None);
    }

    /// `factual` sets temperature, top_p, and the caller's limit.
    #[test]
    fn test_chat_options_factual() {
        let opts = ChatOptions::factual(200);
        assert_eq!(opts.temperature, Some(0.1));
        assert_eq!(opts.max_tokens, Some(200));
        assert_eq!(opts.top_p, Some(0.9));
        assert_eq!(opts.stop, None);
        assert_eq!(opts.system, None);
    }

    /// `creative` sets temperature and the caller's limit, nothing else.
    #[test]
    fn test_chat_options_creative() {
        let opts = ChatOptions::creative(400);
        assert_eq!(opts.temperature, Some(0.3));
        assert_eq!(opts.max_tokens, Some(400));
        assert_eq!(opts.top_p, None);
        assert_eq!(opts.stop, None);
        assert_eq!(opts.system, None);
    }
}