1#![doc = include_str!("../README.md")]
2
3pub mod llms;
4mod sse;
5
/// Token limit used by `PromptOptions::default()`.
pub const DEFAULT_MAX_TOKENS: usize = 4_096;
/// Sampling temperature used by `PromptOptions::default()`.
pub const DEFAULT_TEMPERATURE: f32 = 1.0;
8
9#[derive(Debug, thiserror::Error)]
36pub enum PromptError {
37 #[error("failed to build request to model")]
38 RequestError(#[from] hyper::http::Error),
39 #[error("failed to transcode prompt or response")]
40 TranscodingError(#[from] serde_json::Error),
41}
42
43pub struct ToolParameter<'a> {
44 pub name: &'a str,
45 pub description: &'a str,
46 pub parameters: &'a serde_json::Value,
47}
48
49#[derive(Debug, Clone, PartialEq)]
51pub struct ToolParameters {
52 inner: schemars::schema::Schema,
53}
54
55impl ToolParameters {
56 pub fn new<S: schemars::JsonSchema>() -> Self {
57 let mut generator = schemars::gen::SchemaGenerator::default();
58 Self {
59 inner: <S as schemars::JsonSchema>::json_schema(&mut generator),
60 }
61 }
62}
63
64#[derive(Debug, Clone, PartialEq)]
66pub struct Tool {
67 pub name: String,
68 pub description: String,
69 pub parameters: ToolParameters,
70}
71
/// Requested amount of model "reasoning", from least to most.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}

impl ReasoningEffort {
    /// Token count associated with this effort level: 1024, 2048 or 4096.
    ///
    /// NOTE(review): presumably used as the reasoning/thinking budget by the backends in
    /// `llms` — confirm there.
    fn max_tokens(&self) -> usize {
        match *self {
            ReasoningEffort::Low => 1_024,
            ReasoningEffort::Medium => 2_048,
            ReasoningEffort::High => 4_096,
        }
    }
}
91
92#[derive(Debug, Clone, PartialEq)]
93pub struct PromptOptions {
94 pub max_tokens: usize,
95 pub temperature: f32,
96 pub system_prompt: Option<String>,
97 pub stopping_sequences: Vec<String>,
98 pub tools: Vec<Tool>,
99 pub reasoning: Option<ReasoningEffort>,
100}
101
102impl Default for PromptOptions {
103 fn default() -> Self {
104 Self {
105 max_tokens: DEFAULT_MAX_TOKENS,
106 temperature: DEFAULT_TEMPERATURE,
107 system_prompt: None,
108 stopping_sequences: vec![],
109 tools: vec![],
110 reasoning: None,
111 }
112 }
113}
114
115impl PromptOptions {
116 pub fn set_max_tokens(&mut self, max_tokens: usize) -> &mut Self {
117 self.max_tokens = max_tokens;
118 self
119 }
120 pub fn set_temperature(&mut self, temperature: f32) -> &mut Self {
121 self.temperature = temperature;
122 self
123 }
124 pub fn set_system_prompt(&mut self, system_prompt: String) -> &mut Self {
125 self.system_prompt = Some(system_prompt);
126 self
127 }
128 pub fn set_stopping_sequences(&mut self, stopping_sequences: Vec<String>) -> &mut Self {
129 self.stopping_sequences = stopping_sequences;
130 self
131 }
132
133 pub fn max_tokens(&self) -> usize {
134 self.max_tokens
135 }
136 pub fn temperature(&self) -> f32 {
137 self.temperature
138 }
139 pub fn system_prompt(&self) -> Option<&str> {
140 self.system_prompt.as_deref()
141 }
142 pub fn stopping_sequences(&self) -> &[String] {
143 &self.stopping_sequences[..]
144 }
145}
146
147pub struct SerializedJson {
149 raw: serde_json::Value,
150 serialized: String,
151}
152
153impl SerializedJson {
154 pub fn try_new(value: serde_json::Value) -> serde_json::Result<Self> {
156 Ok(Self {
157 serialized: serde_json::to_string(&value)?,
158 raw: value,
159 })
160 }
161}
162
163pub enum Message {
164 User(String),
165 Assistant(String),
166 ToolRequest {
167 id: String,
168 name: String,
169 arguments: SerializedJson,
170 },
171 ToolResponse {
172 content: String,
173 id: String,
174 },
175}
176
177pub trait LLM {
179 type TokenStream: futures::Stream<Item = Result<Chunk, TokenError>> + Send;
180
181 fn prompt(
184 &self,
185 messages: &[Message],
186 options: &PromptOptions,
187 ) -> Result<Self::TokenStream, PromptError>;
188}
189
190mod sealed {
191 pub trait TokenStreamExtSealed {}
192 impl<T> TokenStreamExtSealed for T where
193 T: futures::Stream<Item = Result<super::Chunk, super::TokenError>> + Send
194 {
195 }
196}
197pub trait TokenStreamExt: sealed::TokenStreamExtSealed {
199 fn all_tokens(self)
202 -> impl std::future::Future<Output = Result<Vec<Chunk>, TokenError>> + Send;
203}
204impl<T> TokenStreamExt for T
205where
206 T: sealed::TokenStreamExtSealed + futures::Stream<Item = Result<Chunk, TokenError>> + Send,
207{
208 async fn all_tokens(self) -> Result<Vec<Chunk>, TokenError> {
209 use futures::StreamExt;
210 let mut stream = Box::pin(self);
211
212 let mut acc = vec![];
213
214 while let Some(token) = stream.next().await {
215 tracing::debug!("received token in all_tokens: {:?}", token);
216 if let Some(last_acc) = acc.last_mut() {
217 match (last_acc, token?) {
218 (Chunk::Token(lhs), Chunk::Token(rhs)) => lhs.push_str(&rhs),
219 (Chunk::Thinking(lhs), Chunk::Thinking(rhs)) => lhs.push_str(&rhs),
220 (Chunk::ToolCall(lhs), Chunk::ToolCall(rhs))
221 if lhs.id.as_ref().is_none_or(|lhs_id| {
222 rhs.id.as_ref().is_none_or(|rhs_id| lhs_id == rhs_id)
223 }) =>
224 {
225 lhs.id = lhs.id.take().or(rhs.id);
226 lhs.name = lhs.name.take().or(rhs.name);
227 lhs.arguments.push_str(&rhs.arguments);
228 }
229 (_, token) => acc.push(token),
230 }
231 } else {
232 acc.push(token?);
233 };
234 }
235
236 Ok(acc)
237 }
238}
239
/// A tool call, or a streamed fragment of one.
///
/// Fragments are combined by `TokenStreamExt::all_tokens`. `id` and `name` may be absent
/// on later fragments, and `arguments` text is concatenated across fragments.
#[derive(Clone, Debug)]
pub struct ToolCallChunk {
    /// Tool-call id, if this fragment carries one.
    pub id: Option<String>,
    /// Name of the tool being called, if this fragment carries one.
    pub name: Option<String>,
    /// Argument text. It is parsed as JSON when converted into a `Message`.
    pub arguments: String,
}
246
247#[derive(Debug, Clone)]
248pub enum Chunk {
249 Token(String),
250 Thinking(String),
251 ToolCall(ToolCallChunk),
252}
253
254impl Chunk {
255 pub fn try_into_message(self) -> Option<Message> {
256 match self {
257 Chunk::Token(content) => Some(Message::Assistant(content)),
258 Chunk::Thinking(_) => None,
259 Chunk::ToolCall(tool_call_chunk) => Some(Message::ToolRequest {
260 id: tool_call_chunk.id?,
261 name: tool_call_chunk.name?,
262 arguments: SerializedJson::try_new(
263 serde_json::from_str::<serde_json::Value>(&tool_call_chunk.arguments).ok()?,
264 )
265 .ok()?,
266 }),
267 }
268 }
269}
270
271#[derive(Debug, thiserror::Error)]
272pub enum TokenError {
273 #[error("the connection was lost")]
274 ConnectionLost(#[from] sse::Error),
275 #[error("the server responded with an unknown event type `{0}`")]
276 UnknownEventType(String),
277 #[error("the server responded with unexpected data: {message}")]
278 MalformedResponse {
279 message: &'static str,
280 value: serde_json::Value,
281 },
282}
283
284pub use schemars::JsonSchema;
285pub use serde;
286pub use serde_json;
287pub use sse::Error as SseError;
288
289trait JsonExt {
290 fn take_str(&mut self) -> Option<String>;
291}
292
293impl JsonExt for serde_json::Value {
294 fn take_str(&mut self) -> Option<String> {
295 if let serde_json::Value::String(s) = self.take() {
296 Some(s)
297 } else {
298 None
299 }
300 }
301}