// lmql/lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod llms;
4mod sse;
5
/// Default maximum number of tokens a model may generate per response.
pub const DEFAULT_MAX_TOKENS: usize = 4096;
/// Default sampling temperature.
pub const DEFAULT_TEMPERATURE: f32 = 1.0;
8
9//pub use lmql_macros::*;
10//#[macro_export]
11/*macro_rules! prompt {
12    ($model:expr => $(
13        user: $prompt:literal;
14        assistant: $response:literal $(where $($out:ident : $out_ty:ty),* $(,)?)?
15    );* $(;)?) => {async {
16        let res = $crate::prompt_inner!($model => $(
17            user: $prompt;
18            assistant: $response $(where $($out : $out_ty),*)*;
19        )*).await;
20
21        // Formatting in IDE.
22        if let Ok(res) = res {
23            if false {
24                $(
25                    let _ = format!($prompt);
26                    let _ = format!($response, $($($out = res.$out),*)* );
27                )*
28            }
29        }
30
31        res
32    }};
33}*/
34
/// Errors that can occur while building or encoding a prompt request.
#[derive(Debug, thiserror::Error)]
pub enum PromptError {
    /// The HTTP request to the model could not be constructed.
    #[error("failed to build request to model")]
    RequestError(#[from] hyper::http::Error),
    /// JSON (de)serialization of the prompt or response failed.
    #[error("failed to transcode prompt or response")]
    TranscodingError(#[from] serde_json::Error),
}
42
/// A borrowed description of a tool, with its parameters as a raw JSON value.
///
/// NOTE(review): nothing in this file constructs or consumes this type —
/// presumably used by the backend implementations in `llms`; confirm.
pub struct ToolParameter<'a> {
    /// The tool's name.
    pub name: &'a str,
    /// A human-readable description of what the tool does.
    pub description: &'a str,
    /// A raw JSON value describing the tool's parameters.
    pub parameters: &'a serde_json::Value,
}
48
/// The parameters of a tool available to an LLM.
///
/// Wraps a JSON schema generated from a type's [`schemars::JsonSchema`]
/// implementation; see [`ToolParameters::new`].
#[derive(Debug, Clone, PartialEq)]
pub struct ToolParameters {
    // The generated JSON schema for the tool's arguments.
    inner: schemars::schema::Schema,
}
54
55impl ToolParameters {
56    pub fn new<S: schemars::JsonSchema>() -> Self {
57        let mut generator = schemars::gen::SchemaGenerator::default();
58        Self {
59            inner: <S as schemars::JsonSchema>::json_schema(&mut generator),
60        }
61    }
62}
63
/// A tool accessible to an LLM.
#[derive(Debug, Clone, PartialEq)]
pub struct Tool {
    /// The name the model uses to invoke the tool.
    pub name: String,
    /// A description telling the model what the tool does.
    pub description: String,
    /// The schema of the tool's arguments.
    pub parameters: ToolParameters,
}
71
/// The effort to put into reasoning.
/// For non-reasoning models, this is ignored.
/// For non-open-ai models, this corresponds to the maximum number of tokens to use for reasoning.
// `Eq` and `Hash` are free for a fieldless enum and let callers use the
// effort level as a map/set key; `PartialEq` alone was needlessly restrictive.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    /// Smallest reasoning budget (1024 tokens on token-budget backends).
    Low,
    /// Moderate reasoning budget (2048 tokens).
    Medium,
    /// Largest reasoning budget (4096 tokens).
    High,
}
81
82impl ReasoningEffort {
83    fn max_tokens(&self) -> usize {
84        match self {
85            Self::Low => 1024,
86            Self::Medium => 2048,
87            Self::High => 4096,
88        }
89    }
90}
91
/// Options controlling a single prompt to an [`LLM`].
#[derive(Debug, Clone, PartialEq)]
pub struct PromptOptions {
    /// Maximum number of tokens to generate; defaults to [`DEFAULT_MAX_TOKENS`].
    pub max_tokens: usize,
    /// Sampling temperature; defaults to [`DEFAULT_TEMPERATURE`].
    pub temperature: f32,
    /// Optional system prompt sent ahead of the conversation.
    pub system_prompt: Option<String>,
    /// Sequences that stop generation when produced by the model.
    pub stopping_sequences: Vec<String>,
    /// Tools the model may call.
    pub tools: Vec<Tool>,
    /// Reasoning effort; `None` disables reasoning. Ignored by
    /// non-reasoning models (see [`ReasoningEffort`]).
    pub reasoning: Option<ReasoningEffort>,
}
101
102impl Default for PromptOptions {
103    fn default() -> Self {
104        Self {
105            max_tokens: DEFAULT_MAX_TOKENS,
106            temperature: DEFAULT_TEMPERATURE,
107            system_prompt: None,
108            stopping_sequences: vec![],
109            tools: vec![],
110            reasoning: None,
111        }
112    }
113}
114
115impl PromptOptions {
116    pub fn set_max_tokens(&mut self, max_tokens: usize) -> &mut Self {
117        self.max_tokens = max_tokens;
118        self
119    }
120    pub fn set_temperature(&mut self, temperature: f32) -> &mut Self {
121        self.temperature = temperature;
122        self
123    }
124    pub fn set_system_prompt(&mut self, system_prompt: String) -> &mut Self {
125        self.system_prompt = Some(system_prompt);
126        self
127    }
128    pub fn set_stopping_sequences(&mut self, stopping_sequences: Vec<String>) -> &mut Self {
129        self.stopping_sequences = stopping_sequences;
130        self
131    }
132
133    pub fn max_tokens(&self) -> usize {
134        self.max_tokens
135    }
136    pub fn temperature(&self) -> f32 {
137        self.temperature
138    }
139    pub fn system_prompt(&self) -> Option<&str> {
140        self.system_prompt.as_deref()
141    }
142    pub fn stopping_sequences(&self) -> &[String] {
143        &self.stopping_sequences[..]
144    }
145}
146
147/// Some `serde_json::Value` that has been serialized to a string.
148pub struct SerializedJson {
149    raw: serde_json::Value,
150    serialized: String,
151}
152
impl SerializedJson {
    /// Serializes `value` eagerly and stores both the structured and the
    /// string form.
    ///
    /// Serializing a `serde_json::Value` is almost always infallible, but
    /// `serde_json::to_string` still returns a `Result`, so the error is
    /// propagated rather than unwrapped.
    pub fn try_new(value: serde_json::Value) -> serde_json::Result<Self> {
        Ok(Self {
            serialized: serde_json::to_string(&value)?,
            raw: value,
        })
    }
}
162
/// A single entry in a conversation passed to [`LLM::prompt`].
pub enum Message {
    /// Text sent by the user.
    User(String),
    /// Text previously produced by the assistant.
    Assistant(String),
    /// A tool invocation requested by the assistant.
    ToolRequest {
        /// Identifier pairing this request with its [`Message::ToolResponse`].
        id: String,
        /// The name of the tool to invoke.
        name: String,
        /// The tool's arguments as pre-serialized JSON.
        arguments: SerializedJson,
    },
    /// The result of executing a requested tool.
    ToolResponse {
        /// The tool's output.
        content: String,
        /// The `id` of the corresponding [`Message::ToolRequest`].
        id: String,
    },
}
176
/// Some hook into an LLM, which can be used to generate text.
pub trait LLM {
    /// The stream of response [`Chunk`]s produced by [`LLM::prompt`].
    type TokenStream: futures::Stream<Item = Result<Chunk, TokenError>> + Send;

    /// Generates a response to the given conversation.
    ///
    /// `messages` is the conversation history: user/assistant turns plus any
    /// tool requests and responses. Returns a stream of response chunks, or a
    /// [`PromptError`] if the request could not be built or encoded.
    fn prompt(
        &self,
        messages: &[Message],
        options: &PromptOptions,
    ) -> Result<Self::TokenStream, PromptError>;
}
189
// Sealed-trait pattern: `TokenStreamExtSealed` cannot be named outside this
// crate, so downstream crates cannot implement `TokenStreamExt` themselves.
mod sealed {
    pub trait TokenStreamExtSealed {}
    // Blanket impl: every sendable stream of `Result<Chunk, TokenError>`
    // is covered, matching the bound on `TokenStreamExt`'s impl below.
    impl<T> TokenStreamExtSealed for T where
        T: futures::Stream<Item = Result<super::Chunk, super::TokenError>> + Send
    {
    }
}
/// Utility methods for token stream sources.
pub trait TokenStreamExt: sealed::TokenStreamExtSealed {
    /// Converts the stream of tokens into a single set of tokens future, collapsing adjacent like tokens.
    /// This is useful for when you don't want to filter the tokens as they arrive.
    ///
    /// Returns the first stream error encountered, discarding chunks
    /// accumulated so far.
    fn all_tokens(self)
        -> impl std::future::Future<Output = Result<Vec<Chunk>, TokenError>> + Send;
}
impl<T> TokenStreamExt for T
where
    T: sealed::TokenStreamExtSealed + futures::Stream<Item = Result<Chunk, TokenError>> + Send,
{
    /// Drains the stream, coalescing adjacent chunks of the same variant.
    async fn all_tokens(self) -> Result<Vec<Chunk>, TokenError> {
        use futures::StreamExt;
        // Pin on the heap so `next()` can be called on an unnamed stream type.
        let mut stream = Box::pin(self);

        let mut acc = vec![];

        while let Some(token) = stream.next().await {
            tracing::debug!("received token in all_tokens: {:?}", token);
            if let Some(last_acc) = acc.last_mut() {
                // `token?` propagates the first stream error, discarding `acc`.
                match (last_acc, token?) {
                    // Adjacent text (or thinking) fragments are concatenated.
                    (Chunk::Token(lhs), Chunk::Token(rhs)) => lhs.push_str(&rhs),
                    (Chunk::Thinking(lhs), Chunk::Thinking(rhs)) => lhs.push_str(&rhs),
                    // Adjacent tool-call fragments merge only when their ids
                    // don't conflict: either side missing an id, or both equal.
                    (Chunk::ToolCall(lhs), Chunk::ToolCall(rhs))
                        if lhs.id.as_ref().is_none_or(|lhs_id| {
                            rhs.id.as_ref().is_none_or(|rhs_id| lhs_id == rhs_id)
                        }) =>
                    {
                        // Keep the first non-None id/name seen; `arguments`
                        // arrives as partial JSON text and is concatenated.
                        lhs.id = lhs.id.take().or(rhs.id);
                        lhs.name = lhs.name.take().or(rhs.name);
                        lhs.arguments.push_str(&rhs.arguments);
                    }
                    // Variant change (or conflicting tool-call ids): start a
                    // new accumulated chunk.
                    (_, token) => acc.push(token),
                }
            } else {
                // First chunk: nothing to merge into yet.
                acc.push(token?);
            };
        }

        Ok(acc)
    }
}
239
/// A fragment of a tool call, as streamed incrementally from the model.
#[derive(Debug, Clone)]
pub struct ToolCallChunk {
    /// The tool-call id; may be absent on continuation fragments.
    pub id: Option<String>,
    /// The tool name; may be absent on continuation fragments.
    pub name: Option<String>,
    /// A fragment of the JSON-encoded arguments; adjacent fragments are
    /// concatenated by `TokenStreamExt::all_tokens`.
    pub arguments: String,
}
246
/// A single streamed piece of a model response.
#[derive(Debug, Clone)]
pub enum Chunk {
    /// Ordinary response text.
    Token(String),
    /// Reasoning ("thinking") text; dropped by `Chunk::try_into_message`.
    Thinking(String),
    /// A (possibly partial) tool call.
    ToolCall(ToolCallChunk),
}
253
254impl Chunk {
255    pub fn try_into_message(self) -> Option<Message> {
256        match self {
257            Chunk::Token(content) => Some(Message::Assistant(content)),
258            Chunk::Thinking(_) => None,
259            Chunk::ToolCall(tool_call_chunk) => Some(Message::ToolRequest {
260                id: tool_call_chunk.id?,
261                name: tool_call_chunk.name?,
262                arguments: SerializedJson::try_new(
263                    serde_json::from_str::<serde_json::Value>(&tool_call_chunk.arguments).ok()?,
264                )
265                .ok()?,
266            }),
267        }
268    }
269}
270
/// Errors that can occur while streaming tokens from a model.
#[derive(Debug, thiserror::Error)]
pub enum TokenError {
    /// The underlying SSE connection failed.
    #[error("the connection was lost")]
    ConnectionLost(#[from] sse::Error),
    /// The server sent an SSE event type this crate does not recognize.
    #[error("the server responded with an unknown event type `{0}`")]
    UnknownEventType(String),
    /// The server sent JSON that does not match the expected shape.
    #[error("the server responded with unexpected data: {message}")]
    MalformedResponse {
        /// A static description of what was malformed.
        message: &'static str,
        /// The offending JSON value, kept for diagnostics.
        value: serde_json::Value,
    },
}
283
284pub use schemars::JsonSchema;
285pub use serde;
286pub use serde_json;
287pub use sse::Error as SseError;
288
// Internal convenience for pulling owned strings out of JSON values.
trait JsonExt {
    /// Takes the string out of `self`, leaving `Null` in its place; returns
    /// `None` (still replacing `self` with `Null`) if it is not a string.
    fn take_str(&mut self) -> Option<String>;
}
292
293impl JsonExt for serde_json::Value {
294    fn take_str(&mut self) -> Option<String> {
295        if let serde_json::Value::String(s) = self.take() {
296            Some(s)
297        } else {
298            None
299        }
300    }
301}