#![doc = include_str!("../README.md")]
pub mod llms;
mod sse;
/// Value `PromptOptions::default()` assigns to its `max_tokens` field.
pub const DEFAULT_MAX_TOKENS: usize = 4096;
/// Value `PromptOptions::default()` assigns to its `temperature` field.
pub const DEFAULT_TEMPERATURE: f32 = 1.0;
/// Error type returned by `LLM::prompt` when no token stream could be produced.
#[derive(Debug, thiserror::Error)]
pub enum PromptError {
/// Wraps a `hyper::http::Error`; the `#[from]` attribute provides the `From` conversion.
#[error("failed to build request to model")]
RequestError(#[from] hyper::http::Error),
/// Wraps a `serde_json::Error`; the `#[from]` attribute provides the `From` conversion.
#[error("failed to transcode prompt or response")]
TranscodingError(#[from] serde_json::Error),
}
/// Borrowed view of a tool: its name, its description and a JSON value for its parameters.
///
/// NOTE(review): nothing in this file constructs or reads this type; presumably the
/// provider modules under `llms` use it when serializing tools — confirm.
pub struct ToolParameter<'a> {
/// Name of the tool.
pub name: &'a str,
/// Description of the tool.
pub description: &'a str,
/// JSON value describing the tool's parameters (presumably a JSON schema — confirm).
pub parameters: &'a serde_json::Value,
}
/// Parameter schema of a [`Tool`], stored as a `schemars` schema.
///
/// Built with `ToolParameters::new::<S>()` from any type implementing `schemars::JsonSchema`.
#[derive(Debug, Clone, PartialEq)]
pub struct ToolParameters {
// Schema produced by `<S as JsonSchema>::json_schema`.
inner: schemars::schema::Schema,
}
impl ToolParameters {
    /// Builds the parameter schema for `S` from its [`schemars::JsonSchema`] implementation.
    ///
    /// Only the schema value returned by `S::json_schema` is kept; the generator, and with
    /// it the `definitions` table it accumulates, is dropped when this function returns.
    /// With the default settings nested types are emitted as `$ref`s into that table, which
    /// would leave the stored schema with dangling references. The generator is therefore
    /// configured to inline subschemas so the stored schema is self-contained.
    ///
    /// Per the `schemars` documentation, references may still be generated for recursive
    /// types even when inlining is enabled.
    pub fn new<S: schemars::JsonSchema>() -> Self {
        let mut generator = schemars::gen::SchemaSettings::default()
            .with(|settings| settings.inline_subschemas = true)
            .into_generator();
        Self {
            inner: <S as schemars::JsonSchema>::json_schema(&mut generator),
        }
    }
}
/// A tool that can be listed in `PromptOptions::tools`.
#[derive(Debug, Clone, PartialEq)]
pub struct Tool {
/// Name of the tool.
pub name: String,
/// Description of the tool.
pub description: String,
/// Schema of the arguments the tool accepts.
pub parameters: ToolParameters,
}
/// Reasoning effort level selectable through `PromptOptions::reasoning`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ReasoningEffort {
    /// Lowest level; associated with 1024 tokens.
    Low,
    /// Middle level; associated with 2048 tokens.
    Medium,
    /// Highest level; associated with 4096 tokens.
    High,
}

impl ReasoningEffort {
    /// Returns the token count tied to this effort level.
    ///
    /// NOTE(review): presumably the reasoning-token budget sent to the provider — confirm
    /// against the callers in `llms`.
    fn max_tokens(&self) -> usize {
        const LOW_TOKENS: usize = 1024;
        const MEDIUM_TOKENS: usize = 2048;
        const HIGH_TOKENS: usize = 4096;

        match *self {
            ReasoningEffort::Low => LOW_TOKENS,
            ReasoningEffort::Medium => MEDIUM_TOKENS,
            ReasoningEffort::High => HIGH_TOKENS,
        }
    }
}
/// Options passed alongside the messages to `LLM::prompt`.
///
/// `PromptOptions::default()` yields `DEFAULT_MAX_TOKENS`, `DEFAULT_TEMPERATURE`, no system
/// prompt, no stopping sequences, no tools and no reasoning effort.
#[derive(Debug, Clone, PartialEq)]
pub struct PromptOptions {
/// Token limit (presumably for the generated response — confirm in `llms`).
pub max_tokens: usize,
/// Sampling temperature.
pub temperature: f32,
/// Optional system prompt; `None` when unset.
pub system_prompt: Option<String>,
/// Stopping sequences; empty by default.
pub stopping_sequences: Vec<String>,
/// Tools listed for the prompt; empty by default.
pub tools: Vec<Tool>,
/// Optional reasoning effort; `None` by default.
pub reasoning: Option<ReasoningEffort>,
}
impl Default for PromptOptions {
    /// Creates options with the crate-level default token limit and temperature and every
    /// optional setting left unset or empty.
    fn default() -> Self {
        let max_tokens = DEFAULT_MAX_TOKENS;
        let temperature = DEFAULT_TEMPERATURE;

        Self {
            max_tokens,
            temperature,
            system_prompt: None,
            stopping_sequences: Vec::new(),
            tools: Vec::new(),
            reasoning: None,
        }
    }
}
impl PromptOptions {
pub fn set_max_tokens(&mut self, max_tokens: usize) -> &mut Self {
self.max_tokens = max_tokens;
self
}
pub fn set_temperature(&mut self, temperature: f32) -> &mut Self {
self.temperature = temperature;
self
}
pub fn set_system_prompt(&mut self, system_prompt: String) -> &mut Self {
self.system_prompt = Some(system_prompt);
self
}
pub fn set_stopping_sequences(&mut self, stopping_sequences: Vec<String>) -> &mut Self {
self.stopping_sequences = stopping_sequences;
self
}
pub fn max_tokens(&self) -> usize {
self.max_tokens
}
pub fn temperature(&self) -> f32 {
self.temperature
}
pub fn system_prompt(&self) -> Option<&str> {
self.system_prompt.as_deref()
}
pub fn stopping_sequences(&self) -> &[String] {
&self.stopping_sequences[..]
}
}
/// A JSON value stored together with its serialized string form.
///
/// Both fields are filled once by `SerializedJson::try_new`.
pub struct SerializedJson {
// The JSON value as passed to `try_new`.
raw: serde_json::Value,
// Output of `serde_json::to_string` for `raw`.
serialized: String,
}
impl SerializedJson {
    /// Wraps `value`, serializing it to a string up front.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `value` cannot be serialized.
    pub fn try_new(value: serde_json::Value) -> serde_json::Result<Self> {
        // Serialize before moving `value` into the struct.
        let serialized = serde_json::to_string(&value)?;
        Ok(Self {
            raw: value,
            serialized,
        })
    }
}
/// One entry of the conversation passed to `LLM::prompt`.
pub enum Message {
/// Text message from the user.
User(String),
/// Text message from the assistant; `Chunk::try_into_message` builds this from `Chunk::Token`.
Assistant(String),
/// A tool invocation; `Chunk::try_into_message` builds this from a complete `Chunk::ToolCall`.
ToolRequest {
/// Identifier of the tool call.
id: String,
/// Name of the tool being called.
name: String,
/// Arguments of the call as JSON.
arguments: SerializedJson,
},
/// Result of a tool invocation.
ToolResponse {
/// Content of the tool's response.
content: String,
/// Identifier of the tool call (presumably matches the `id` of the `ToolRequest` — confirm).
id: String,
},
}
/// Interface of a model that answers a list of messages with a stream of chunks.
pub trait LLM {
/// Stream type returned by `prompt`; it yields `Result<Chunk, TokenError>` items and is `Send`.
type TokenStream: futures::Stream<Item = Result<Chunk, TokenError>> + Send;
/// Prompts the model with `messages` and `options`.
///
/// # Errors
///
/// Returns a `PromptError` when no stream could be produced. Failures that occur later
/// are reported as `TokenError` items of the returned stream.
fn prompt(
&self,
messages: &[Message],
options: &PromptOptions,
) -> Result<Self::TokenStream, PromptError>;
}
// Private module holding the supertrait of `TokenStreamExt`. Because the module is not
// public, code outside this crate cannot name `TokenStreamExtSealed`.
mod sealed {
/// Marker supertrait required by `TokenStreamExt`.
pub trait TokenStreamExtSealed {}
// Blanket impl: every `Send` stream of `Result<Chunk, TokenError>` gets the marker.
impl<T> TokenStreamExtSealed for T where
T: futures::Stream<Item = Result<super::Chunk, super::TokenError>> + Send
{
}
}
/// Extension methods for token streams (`Send` streams of `Result<Chunk, TokenError>`).
///
/// The trait requires `sealed::TokenStreamExtSealed`, which lives in a private module, and
/// is implemented by a blanket impl in this file.
pub trait TokenStreamExt: sealed::TokenStreamExtSealed {
/// Consumes the stream and returns its chunks, or the first `TokenError` encountered.
///
/// The blanket implementation in this file merges consecutive chunks of the same kind
/// into a single chunk.
fn all_tokens(self)
-> impl std::future::Future<Output = Result<Vec<Chunk>, TokenError>> + Send;
}
impl<T> TokenStreamExt for T
where
    T: sealed::TokenStreamExtSealed + futures::Stream<Item = Result<Chunk, TokenError>> + Send,
{
    /// Drains the stream and returns its chunks with consecutive chunks coalesced.
    ///
    /// A chunk is merged into the previously collected chunk when:
    /// - both are `Chunk::Token`, or both are `Chunk::Thinking` (the text is appended);
    /// - both are `Chunk::ToolCall` and their ids do not conflict, i.e. at least one id is
    ///   `None` or both ids are equal. The merged call keeps the first `Some` id and name
    ///   and appends the argument fragments.
    ///
    /// Any other chunk starts a new entry.
    ///
    /// # Errors
    ///
    /// Returns the first `TokenError` yielded by the stream; chunks collected so far are
    /// discarded.
    async fn all_tokens(self) -> Result<Vec<Chunk>, TokenError> {
        use futures::StreamExt;

        // The stream never leaves this future, so it can be pinned in place instead of
        // being moved to the heap.
        let mut stream = std::pin::pin!(self);
        let mut acc: Vec<Chunk> = Vec::new();

        while let Some(token) = stream.next().await {
            tracing::debug!("received token in all_tokens: {:?}", token);
            let chunk = token?;

            match (acc.last_mut(), chunk) {
                (Some(Chunk::Token(lhs)), Chunk::Token(rhs)) => lhs.push_str(&rhs),
                (Some(Chunk::Thinking(lhs)), Chunk::Thinking(rhs)) => lhs.push_str(&rhs),
                // Two tool-call fragments belong together unless both carry an id and the
                // ids differ.
                (Some(Chunk::ToolCall(lhs)), Chunk::ToolCall(rhs))
                    if lhs.id.as_ref().is_none_or(|lhs_id| {
                        rhs.id.as_ref().is_none_or(|rhs_id| lhs_id == rhs_id)
                    }) =>
                {
                    lhs.id = lhs.id.take().or(rhs.id);
                    lhs.name = lhs.name.take().or(rhs.name);
                    lhs.arguments.push_str(&rhs.arguments);
                }
                // First chunk, a change of kind, or a tool call with a different id.
                (_, chunk) => acc.push(chunk),
            }
        }

        Ok(acc)
    }
}
/// A (possibly partial) tool call received from a token stream.
///
/// `TokenStreamExt::all_tokens` merges consecutive fragments by keeping the first `Some`
/// id and name and concatenating `arguments`.
#[derive(Debug, Clone)]
pub struct ToolCallChunk {
/// Identifier of the tool call, when this fragment carries one.
pub id: Option<String>,
/// Name of the tool, when this fragment carries one.
pub name: Option<String>,
/// Argument text of this fragment; `Chunk::try_into_message` parses it as JSON.
pub arguments: String,
}
/// One item produced by a token stream.
#[derive(Debug, Clone)]
pub enum Chunk {
/// Response text; converted to `Message::Assistant` by `try_into_message`.
Token(String),
/// Thinking text; `try_into_message` yields `None` for it.
Thinking(String),
/// A tool-call fragment; converted to `Message::ToolRequest` by `try_into_message` when complete.
ToolCall(ToolCallChunk),
}
impl Chunk {
pub fn try_into_message(self) -> Option<Message> {
match self {
Chunk::Token(content) => Some(Message::Assistant(content)),
Chunk::Thinking(_) => None,
Chunk::ToolCall(tool_call_chunk) => Some(Message::ToolRequest {
id: tool_call_chunk.id?,
name: tool_call_chunk.name?,
arguments: SerializedJson::try_new(
serde_json::from_str::<serde_json::Value>(&tool_call_chunk.arguments).ok()?,
)
.ok()?,
}),
}
}
}
/// Error yielded as an item of a token stream.
#[derive(Debug, thiserror::Error)]
pub enum TokenError {
/// Wraps an `sse::Error`; the `#[from]` attribute provides the `From` conversion.
#[error("the connection was lost")]
ConnectionLost(#[from] sse::Error),
/// The server sent an event type that is not recognized; the payload is that event type.
#[error("the server responded with an unknown event type `{0}`")]
UnknownEventType(String),
/// The server sent data that did not have the expected form.
#[error("the server responded with unexpected data: {message}")]
MalformedResponse {
/// Static description of the problem; included in the `Display` output.
message: &'static str,
/// JSON value attached to the error (presumably the offending data — confirm); not part
/// of the `Display` output.
value: serde_json::Value,
},
}
pub use schemars::JsonSchema;
pub use serde;
pub use serde_json;
pub use sse::Error as SseError;
/// Crate-private helper trait; implemented in this file for `serde_json::Value`.
trait JsonExt {
/// Takes a string out of `self`, returning `None` when there is no string to take.
fn take_str(&mut self) -> Option<String>;
}
impl JsonExt for serde_json::Value {
    /// Moves the value out of `self` and returns its contents if it was a JSON string.
    ///
    /// `self` is left as `Value::Null` in every case: `Value::take` runs before the type
    /// is inspected, so a non-string value is removed and dropped as well.
    fn take_str(&mut self) -> Option<String> {
        match self.take() {
            serde_json::Value::String(text) => Some(text),
            _ => None,
        }
    }
}