extern crate self as menta;
pub mod providers;
pub use menta_derive::Tool;
use futures_util::{Stream, stream};
use schemars::JsonSchema;
use schemars::schema::{
InstanceType, RootSchema, Schema, SchemaObject, SingleOrVec, SubschemaValidation,
};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::{Value, json};
use std::any::TypeId;
use std::collections::BTreeMap;
use std::error::Error as StdError;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Process-wide counter used by `ToolCall::new` to mint ids `tool-call-1`, `tool-call-2`, ...
static NEXT_TOOL_CALL_ID: AtomicUsize = AtomicUsize::new(1);
/// Every failure this crate reports.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The model id was not of the form `provider/model`.
    InvalidModelId(String),
    /// No registered provider matches the id's provider prefix.
    UnknownProvider(String),
    /// The provider exists but does not offer the requested model.
    UnsupportedModel(String),
    /// A required environment variable is not set.
    MissingEnvironmentVariable(&'static str),
    /// Transport-level failure.
    Http(String),
    /// The provider API reported an error.
    Api(String),
    /// JSON (de)serialization failure.
    Json(String),
    /// Model output did not match the requested format.
    Parse(String),
}

impl Display for Error {
    /// Renders as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let (category, detail) = match self {
            Self::InvalidModelId(model) => ("invalid model id", model.as_str()),
            Self::UnknownProvider(provider) => ("unknown provider", provider.as_str()),
            Self::UnsupportedModel(model) => ("unsupported model", model.as_str()),
            Self::MissingEnvironmentVariable(name) => ("missing environment variable", *name),
            Self::Http(error) => ("http error", error.as_str()),
            Self::Api(error) => ("api error", error.as_str()),
            Self::Json(error) => ("json error", error.as_str()),
            Self::Parse(error) => ("parse error", error.as_str()),
        };
        write!(f, "{category}: {detail}")
    }
}

impl StdError for Error {}

/// Crate-wide result alias using [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// A provider entry discovered through the `inventory` collection declared below.
///
/// Model ids look like `provider/model`; `id` is matched against the part before the
/// first `/`, and the constructors receive the part after it.
/// NOTE(review): providers presumably register themselves with `inventory::submit!`
/// (e.g. in the `providers` module) — confirm there.
pub struct ProviderRegistration {
/// Provider prefix, e.g. the `mock` in `mock/mock-1`.
pub id: &'static str,
/// Builds a language model from the provider-local model id.
pub language_model: fn(&str) -> Result<Box<dyn LanguageModel>>,
/// Builds an embedding model from the provider-local model id.
pub embedding_model: fn(&str) -> Result<Box<dyn EmbeddingModel>>,
}
// Declares the `inventory` collection iterated by `registered_providers` and the resolvers.
inventory::collect!(ProviderRegistration);
/// Author of a [`ModelMessage`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Role {
/// Instructions placed ahead of the conversation.
System,
/// End-user input.
User,
/// Model output (text and/or tool calls).
Assistant,
/// A tool result returned to the model.
Tool,
}
/// A request from the model to run the tool `name` with raw `input`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolCall {
    /// Identifier echoed back in [`ToolResult::call_id`].
    pub id: String,
    /// Name of the tool to run.
    pub name: String,
    /// Tool input exactly as the model produced it (JSON text for derived tools).
    pub input: String,
}

impl ToolCall {
    /// Creates a call with a freshly minted id of the form `tool-call-<n>`.
    pub fn new(name: impl Into<String>, input: impl Into<String>) -> Self {
        // Relaxed is enough: only the uniqueness of each counter value matters.
        let sequence = NEXT_TOOL_CALL_ID.fetch_add(1, Ordering::Relaxed);
        Self {
            id: format!("tool-call-{sequence}"),
            name: name.into(),
            input: input.into(),
        }
    }
}
/// Outcome of executing a [`ToolCall`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolResult {
/// The `id` of the call this result answers.
pub call_id: String,
/// Name of the tool that produced (or failed to produce) the output.
pub name: String,
/// Output text, or the error message when `is_error` is set.
pub output: String,
/// Set when the executor returned `Err` or the tool was not found.
pub is_error: bool,
}
/// Simplified JSON-schema subset used for tool inputs/outputs and structured output.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ToolSchema {
    String {
        description: Option<String>,
    },
    Integer {
        description: Option<String>,
    },
    Number {
        description: Option<String>,
    },
    Boolean {
        description: Option<String>,
    },
    /// Homogeneous array whose elements all match `items`.
    Array {
        description: Option<String>,
        items: Box<ToolSchema>,
    },
    Object(ToolObjectSchema),
}

impl ToolSchema {
    /// A string schema without a description.
    pub fn string() -> Self {
        Self::String { description: None }
    }

    /// Returns the same schema with its top-level description replaced.
    ///
    /// For `Array` only the array's own description changes, not `items`.
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        let text = Some(description.into());
        match &mut self {
            Self::String { description }
            | Self::Integer { description }
            | Self::Number { description }
            | Self::Boolean { description }
            | Self::Array { description, .. } => *description = text,
            Self::Object(object) => object.description = text,
        }
        self
    }
}

/// An object schema: an optional description plus its declared fields.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolObjectSchema {
    pub description: Option<String>,
    pub fields: Vec<ToolFieldSchema>,
}

/// One named property of a [`ToolObjectSchema`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolFieldSchema {
    pub name: String,
    pub description: Option<String>,
    pub schema: ToolSchema,
    /// Whether the property must be present.
    pub required: bool,
}
/// One piece of message content.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Part {
/// Plain text.
Text(String),
/// A tool invocation requested by the model.
ToolCall(ToolCall),
/// The result of a tool invocation.
ToolResult(ToolResult),
}
/// A single conversation turn: who said it and what it contains.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModelMessage {
    pub role: Role,
    pub parts: Vec<Part>,
}

impl ModelMessage {
    /// Builds a message holding exactly one part.
    fn single_part(role: Role, part: Part) -> Self {
        Self {
            role,
            parts: vec![part],
        }
    }

    /// A system message with one text part.
    pub fn system(text: impl Into<String>) -> Self {
        Self::single_part(Role::System, Part::Text(text.into()))
    }

    /// A user message with one text part.
    pub fn user(text: impl Into<String>) -> Self {
        Self::single_part(Role::User, Part::Text(text.into()))
    }

    /// An assistant message with one text part.
    pub fn assistant_text(text: impl Into<String>) -> Self {
        Self::single_part(Role::Assistant, Part::Text(text.into()))
    }

    /// An assistant message with arbitrary parts (e.g. text plus tool calls).
    pub fn assistant_parts(parts: Vec<Part>) -> Self {
        Self {
            role: Role::Assistant,
            parts,
        }
    }

    /// A tool-role message carrying one tool result.
    pub fn tool_result(result: ToolResult) -> Self {
        Self::single_part(Role::Tool, Part::ToolResult(result))
    }

    /// Concatenates the text parts in order, with no separator; other parts are skipped.
    pub fn text(&self) -> String {
        let mut joined = String::new();
        for part in &self.parts {
            if let Part::Text(text) = part {
                joined.push_str(text);
            }
        }
        joined
    }
}
/// Optional generation settings forwarded to the provider in [`ModelRequest`].
/// NOTE(review): `None` presumably means "use the provider's default" — confirm per provider.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ModelSettings {
pub temperature: Option<f32>,
pub max_output_tokens: Option<usize>,
}
/// How the model may use the available tools.
///
/// `generate_text` relaxes `Required` to `Auto` after the first round of tool calls.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum ToolChoice {
    /// The model decides; this is the default.
    #[default]
    Auto,
    /// The model must not call tools.
    None,
    /// The model must call the named tool.
    Required(String),
}
/// The data-only description of a tool that is sent to the model (no executor).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
/// Schema the model's tool input must follow.
pub input_schema: ToolSchema,
/// Optional schema of the tool's output.
pub output_schema: Option<ToolSchema>,
}
// Boxed future produced by a tool executor; failures are encoded in `ToolResult::is_error`.
type ToolFuture = Pin<Box<dyn Future<Output = ToolResult> + Send + 'static>>;
// Boxed future returned by `LanguageModel` methods; may borrow the model and request for `'a`.
type ModelFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T>> + Send + 'a>>;
/// Stream of [`StreamEvent`]s returned by [`stream_text`] and [`LanguageModel::stream`].
pub type StreamTextStream = Pin<Box<dyn Stream<Item = StreamEvent> + Send + 'static>>;
/// A callable tool: its advertised [`ToolDefinition`] plus a type-erased async executor.
///
/// Cloning shares the executor (it sits behind an `Arc`) and clones the definition.
#[derive(Clone)]
pub struct Tool {
/// Name, description and schemas advertised to the model.
pub definition: ToolDefinition,
// Type-erased executor built by `Tool::new_async`; maps a call to its result.
executor: Arc<dyn Fn(ToolCall) -> ToolFuture + Send + Sync>,
}
impl std::fmt::Debug for Tool {
    /// Shows the definition only; the executor closure has no useful `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Tool");
        builder.field("definition", &self.definition);
        builder.finish_non_exhaustive()
    }
}
impl Tool {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
input_schema: ToolSchema,
executor: impl Fn(&str) -> String + Send + Sync + 'static,
) -> Self {
Self::new_async(name, description, input_schema, None, move |input| {
let output = executor(&input);
async move { Ok(output) }
})
}
pub fn new_async<F, Fut>(
name: impl Into<String>,
description: impl Into<String>,
input_schema: ToolSchema,
output_schema: Option<ToolSchema>,
executor: F,
) -> Self
where
F: Fn(String) -> Fut + Send + Sync + 'static,
Fut: Future<Output = std::result::Result<String, String>> + Send + 'static,
{
let name = name.into();
let description = description.into();
let executor = Arc::new(executor);
Self {
definition: ToolDefinition {
name: name.clone(),
description,
input_schema,
output_schema,
},
executor: Arc::new(move |call| {
let executor = Arc::clone(&executor);
let tool_name = name.clone();
Box::pin(async move {
match executor(call.input.clone()).await {
Ok(output) => ToolResult {
call_id: call.id,
name: tool_name,
output,
is_error: false,
},
Err(output) => ToolResult {
call_id: call.id,
name: tool_name,
output,
is_error: true,
},
}
})
}),
}
}
pub async fn execute(&self, call: &ToolCall) -> ToolResult {
(self.executor)(call.clone()).await
}
pub fn from_execute<T>() -> Self
where
T: ToolExecute,
{
let definition = ToolDefinition {
name: T::tool_name().to_string(),
description: T::tool_description().to_string(),
input_schema: T::input_schema(),
output_schema: T::output_schema(),
};
Self::new_async(
definition.name.clone(),
definition.description.clone(),
definition.input_schema.clone(),
definition.output_schema.clone(),
move |input| async move {
let parsed = serde_json::from_str::<T>(&input).map_err(|error| {
format!("invalid tool input for {}: {error}", T::tool_name())
})?;
let output = parsed.execute().await?;
serde_json::to_string(&output)
.map_err(|error| format!("invalid tool output for {}: {error}", T::tool_name()))
},
)
}
}
/// Static metadata describing a tool's input type.
/// NOTE(review): presumably implemented by `#[derive(Tool)]` from `menta_derive` — confirm there.
pub trait ToolInput {
/// Name the model uses to call the tool.
fn tool_name() -> &'static str;
/// Human-readable description sent to the model.
fn tool_description() -> &'static str;
/// Schema the model's input must follow.
fn input_schema() -> ToolSchema;
/// Definition built from the metadata above, always without an output schema.
/// (`Tool::from_execute` uses `ToolExecute::output_schema` instead.)
fn definition() -> ToolDefinition {
ToolDefinition {
name: Self::tool_name().to_string(),
description: Self::tool_description().to_string(),
input_schema: Self::input_schema(),
output_schema: None,
}
}
}
/// A [`ToolInput`] that can run itself.
///
/// `Tool::from_execute` deserializes the model's JSON input into `Self`, awaits
/// `execute`, and serializes `Output` back to JSON; an `Err(String)` becomes a
/// [`ToolResult`] with `is_error == true`.
pub trait ToolExecute: ToolInput + DeserializeOwned + Send + Sync + 'static {
/// Value returned to the model, serialized with `serde_json`.
type Output: Serialize + Send + Sync + 'static;
/// Optional schema advertised for `Output`; `None` unless overridden.
fn output_schema() -> Option<ToolSchema> {
None
}
/// Performs the tool's work; an `Err` message is sent back to the model as an error result.
fn execute(&self) -> impl Future<Output = std::result::Result<Self::Output, String>> + Send;
/// Builds a [`Tool`] from this type (defaults to `Tool::from_execute::<Self>()`).
fn tool() -> Tool
where
Self: Sized,
{
Tool::from_execute::<Self>()
}
}
/// Builds a [`Tool`] for the [`ToolExecute`] type `T` via `Tool::from_execute::<T>()`.
///
/// This does not go through `T::tool()`, so an overridden `ToolExecute::tool` is not used here.
pub fn tool<T: ToolExecute>() -> Tool {
    Tool::from_execute::<T>()
}
/// Builds a `Vec<Tool>` from a list of [`ToolExecute`] types, e.g. `tools![A, B]`.
///
/// Each entry expands to `$crate::tool::<Type>()`; a trailing comma is accepted.
#[macro_export]
macro_rules! tools {
($($tool:ty),* $(,)?) => {
vec![$($crate::tool::<$tool>()),*]
};
}
/// Builder-style request for [`generate_text`] and [`stream_text`].
///
/// `T` is the desired output type. It exists only at the type level (`_output`), and
/// `generate_text`/`stream_text` use it to choose the output format.
#[derive(Clone, Debug)]
pub struct GenerateTextRequest<T = String> {
/// `provider/model` id, resolved against the registered providers.
pub model: String,
/// System prompt; placed before every other message.
pub system: Option<String>,
/// Single user prompt; only used when `messages` is empty.
pub prompt: Option<String>,
/// Conversation history; takes precedence over `prompt`.
pub messages: Vec<ModelMessage>,
/// Tools advertised to the model (executed by `generate_text`, not by `stream_text`).
pub tools: Vec<Tool>,
pub tool_choice: ToolChoice,
/// Maximum number of model calls made by `generate_text`; `0` is treated as `1`.
pub max_steps: usize,
pub settings: ModelSettings,
/// Free-form provider options, copied unchanged into the [`ModelRequest`].
pub provider_options: BTreeMap<String, String>,
/// Type-level marker for the output type `T`; holds no data.
pub _output: PhantomData<fn() -> T>,
}
// Written by hand because `#[derive(Default)]` would add an unwanted `T: Default` bound.
impl<T> Default for GenerateTextRequest<T> {
    /// An empty request: no model, no prompt/messages, no tools, `max_steps == 0`.
    fn default() -> Self {
        Self {
            model: String::default(),
            system: None,
            prompt: None,
            messages: Vec::default(),
            tools: Vec::default(),
            tool_choice: ToolChoice::Auto,
            max_steps: 0,
            settings: ModelSettings::default(),
            provider_options: BTreeMap::default(),
            _output: PhantomData,
        }
    }
}
impl<T> GenerateTextRequest<T> {
    /// An empty request for an explicitly chosen output type, e.g. `GenerateTextRequest::<MyType>::typed()`.
    pub fn typed() -> Self {
        Self::default()
    }

    /// Sets the `provider/model` id.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }

    /// Sets the system prompt.
    pub fn system(self, system: impl Into<String>) -> Self {
        Self {
            system: Some(system.into()),
            ..self
        }
    }

    /// Sets the single user prompt (ignored when `messages` is non-empty).
    pub fn prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Replaces the conversation history.
    pub fn messages(self, messages: Vec<ModelMessage>) -> Self {
        Self { messages, ..self }
    }

    /// Replaces the tool list.
    pub fn tools(self, tools: Vec<Tool>) -> Self {
        Self { tools, ..self }
    }

    /// Appends the tool built from the [`ToolExecute`] type `U`.
    pub fn tool<U>(mut self) -> Self
    where
        U: ToolExecute,
    {
        self.tools.push(crate::tool::<U>());
        self
    }

    /// Sets the tool choice.
    pub fn tool_choice(self, tool_choice: ToolChoice) -> Self {
        Self {
            tool_choice,
            ..self
        }
    }

    /// Sets the step budget used by `generate_text` (`0` behaves like `1`).
    pub fn max_steps(self, max_steps: usize) -> Self {
        Self { max_steps, ..self }
    }

    /// Sets the generation settings.
    pub fn settings(self, settings: ModelSettings) -> Self {
        Self { settings, ..self }
    }

    /// Adds (or overwrites) one provider option.
    pub fn provider_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.provider_options.insert(key.into(), value.into());
        self
    }

    /// Splits the request into the provider-facing [`ModelRequest`] and the executable tools.
    ///
    /// Message order is: `system` (if any), the output-format instruction (if any), then
    /// `messages` — or a single user message built from `prompt` when `messages` is empty.
    fn into_parts(self, output: &Output) -> (ModelRequest, Vec<Tool>) {
        let Self {
            system,
            prompt,
            messages,
            tools,
            tool_choice,
            settings,
            provider_options,
            ..
        } = self;
        let mut conversation = messages;
        if conversation.is_empty() {
            conversation.extend(prompt.map(ModelMessage::user));
        }
        let messages: Vec<ModelMessage> = system
            .into_iter()
            .chain(output.instruction())
            .map(ModelMessage::system)
            .chain(conversation)
            .collect();
        let definitions = tools.iter().map(|tool| tool.definition.clone()).collect();
        let request = ModelRequest {
            messages,
            tools: definitions,
            tool_choice,
            settings,
            provider_options,
        };
        (request, tools)
    }
}
impl GenerateTextRequest<String> {
    /// An empty request whose output is plain text (`T = String`).
    pub fn new() -> Self {
        Default::default()
    }
}
/// Provider-facing request produced from a [`GenerateTextRequest`].
///
/// Unlike the user request, it carries tool definitions only (no executors) and has the
/// system prompt, output instruction and prompt already folded into `messages`.
#[derive(Clone, Debug)]
pub struct ModelRequest {
pub messages: Vec<ModelMessage>,
pub tools: Vec<ToolDefinition>,
pub tool_choice: ToolChoice,
pub settings: ModelSettings,
pub provider_options: BTreeMap<String, String>,
}
/// Token counts reported for a call (or summed over several calls).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Usage {
pub input_tokens: usize,
pub output_tokens: usize,
}
/// Why generation stopped.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FinishReason {
/// The model finished normally.
Stop,
/// The model stopped to request tool calls.
ToolCalls,
/// A length limit was hit; `generate_text` also reports this when its step budget
/// runs out while tool calls are still pending.
Length,
Error,
}
/// One complete response from a [`LanguageModel`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModelResponse {
pub parts: Vec<Part>,
pub finish_reason: FinishReason,
pub usage: Usage,
/// Provider-supplied metadata (see `metadata_with_provider` for the common keys).
pub response_metadata: BTreeMap<String, String>,
}
impl ModelResponse {
    /// Concatenates every text part, in order, without separators.
    pub fn text(&self) -> String {
        self.parts
            .iter()
            .filter_map(|part| {
                if let Part::Text(text) = part {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>()
    }

    /// Clones out every tool-call part, in order.
    pub fn tool_calls(&self) -> Vec<ToolCall> {
        let mut calls = Vec::new();
        for part in &self.parts {
            if let Part::ToolCall(call) = part {
                calls.push(call.clone());
            }
        }
        calls
    }

    /// Replays the response as stream events: one event per part, then a closing
    /// `Finish` event carrying the finish reason, usage and all parts.
    fn into_stream(self) -> Vec<StreamEvent> {
        let mut events: Vec<StreamEvent> = self
            .parts
            .iter()
            .cloned()
            .map(|part| match part {
                Part::Text(text) => StreamEvent::TextDelta(text),
                Part::ToolCall(call) => StreamEvent::ToolCall(call),
                Part::ToolResult(result) => StreamEvent::ToolResult(result),
            })
            .collect();
        events.push(StreamEvent::Finish {
            reason: self.finish_reason,
            usage: self.usage,
            parts: self.parts,
        });
        events
    }
}
// Turns an already-complete response into a stream (one event per part, then `Finish`);
// used by the default `LanguageModel::stream`.
fn model_response_stream(response: ModelResponse) -> StreamTextStream {
Box::pin(stream::iter(response.into_stream()))
}
/// An item of a [`StreamTextStream`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum StreamEvent {
/// A chunk of text.
TextDelta(String),
/// A tool call requested by the model.
ToolCall(ToolCall),
/// A tool result part.
ToolResult(ToolResult),
/// Final event with the finish reason, usage and the complete list of parts.
Finish {
reason: FinishReason,
usage: Usage,
parts: Vec<Part>,
},
/// An error reported mid-stream.
Error(String),
}
/// A text-generation backend, built through `ProviderRegistration::language_model`.
pub trait LanguageModel: Sync {
/// Identifier of this model.
fn model_id(&self) -> &str;
/// Produces one complete response for `request`.
fn generate<'a>(&'a self, request: &'a ModelRequest) -> ModelFuture<'a, ModelResponse>;
/// Streams a response. The default awaits `generate` and replays the finished
/// response as events, so it is not incremental; providers may override it.
fn stream<'a>(&'a self, request: &'a ModelRequest) -> ModelFuture<'a, StreamTextStream> {
Box::pin(async move { Ok(model_response_stream(self.generate(request).await?)) })
}
}
/// An embedding backend, built through `ProviderRegistration::embedding_model`.
pub trait EmbeddingModel {
    /// Identifier of this model.
    fn model_id(&self) -> &str;

    /// Embeds a single value.
    fn embed(&self, value: &str) -> Result<EmbeddingResult>;

    /// Embeds each value in order by calling [`EmbeddingModel::embed`] once per value.
    ///
    /// Usage is the sum of every per-value usage (both input and output tokens, matching
    /// how `generate_text` aggregates usage across steps).
    ///
    /// # Errors
    /// Returns the first error from `embed`; earlier embeddings are discarded.
    fn embed_many(&self, values: &[String]) -> Result<BatchEmbeddingResult> {
        let mut usage = Usage::default();
        let mut embeddings = Vec::with_capacity(values.len());
        for value in values {
            let result = self.embed(value)?;
            usage.input_tokens += result.usage.input_tokens;
            // Previously dropped, so batch usage under-reported whatever `embed` reported.
            usage.output_tokens += result.usage.output_tokens;
            embeddings.push(result.embedding);
        }
        Ok(BatchEmbeddingResult { embeddings, usage })
    }
}
/// Ids of all registered providers, sorted ascending with duplicates removed.
pub fn registered_providers() -> Vec<&'static str> {
    let mut ids: Vec<&'static str> = Vec::new();
    for registration in inventory::iter::<ProviderRegistration> {
        ids.push(registration.id);
    }
    ids.sort_unstable();
    ids.dedup();
    ids
}
fn split_model_id(model: &str) -> Result<(&str, &str)> {
model
.split_once('/')
.ok_or_else(|| Error::InvalidModelId(model.to_string()))
}
fn resolve_language_model(model: &str) -> Result<Box<dyn LanguageModel>> {
let (provider_id, model_id) = split_model_id(model)?;
for provider in inventory::iter::<ProviderRegistration> {
if provider.id == provider_id {
return (provider.language_model)(model_id);
}
}
Err(Error::UnknownProvider(provider_id.to_string()))
}
fn resolve_embedding_model(model: &str) -> Result<Box<dyn EmbeddingModel>> {
let (provider_id, model_id) = split_model_id(model)?;
for provider in inventory::iter::<ProviderRegistration> {
if provider.id == provider_id {
return (provider.embedding_model)(model_id);
}
}
Err(Error::UnknownProvider(provider_id.to_string()))
}
/// Outcome of [`generate_text`].
#[derive(Clone, Debug, PartialEq)]
pub struct GenerateTextResult<T> {
/// Concatenated text parts of the final response.
pub text: String,
/// `text` parsed according to the request's output type.
pub output: T,
/// All parts of the final response.
pub parts: Vec<Part>,
/// Tool calls in the final response; non-empty only when the step budget ran out.
pub tool_calls: Vec<ToolCall>,
/// The final response's finish reason, or `Length` when the step budget ran out
/// with tool calls still pending.
pub finish_reason: FinishReason,
/// Usage summed over every model call made during generation.
pub usage: Usage,
/// Metadata of the final response.
pub response_metadata: BTreeMap<String, String>,
}
/// Requested output format; drives both the system instruction sent to the model and
/// how the final text is parsed.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Output {
/// Free text, used verbatim.
Text,
/// A JSON object validated against `schema`.
Object {
name: Option<String>,
description: Option<String>,
schema: ToolSchema,
},
/// A JSON array whose elements are validated against `element`.
Array {
name: Option<String>,
description: Option<String>,
element: ToolSchema,
},
/// Exactly one of `options` (surrounding whitespace and quotes are tolerated).
Choice {
name: Option<String>,
description: Option<String>,
options: Vec<String>,
},
/// Any valid JSON, without schema validation.
Json {
name: Option<String>,
description: Option<String>,
},
}
impl Output {
    /// Plain-text output: no instruction is added and the reply is used verbatim.
    pub fn text() -> Self {
        Self::Text
    }

    /// Object output with a schema derived from `T`'s [`JsonSchema`] implementation.
    pub fn object<T: JsonSchema>() -> Self {
        let schema = tool_schema_for::<T>();
        Self::object_schema(schema)
    }

    /// Object output with an explicit schema.
    pub fn object_schema(schema: ToolSchema) -> Self {
        Self::Object {
            schema,
            name: None,
            description: None,
        }
    }

    /// Array output whose element schema is derived from `T`.
    pub fn array<T: JsonSchema>() -> Self {
        let element = tool_schema_for::<T>();
        Self::array_schema(element)
    }

    /// Array output with an explicit element schema.
    pub fn array_schema(element: ToolSchema) -> Self {
        Self::Array {
            element,
            name: None,
            description: None,
        }
    }

    /// Output restricted to one of the given strings.
    pub fn choice(options: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let options: Vec<String> = options.into_iter().map(Into::<String>::into).collect();
        Self::Choice {
            options,
            name: None,
            description: None,
        }
    }

    /// Any valid JSON.
    pub fn json() -> Self {
        Self::Json {
            name: None,
            description: None,
        }
    }

    /// Sets the output name; a no-op for [`Output::Text`].
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        let label = Some(name.into());
        match &mut self {
            Self::Text => {}
            Self::Object { name, .. }
            | Self::Array { name, .. }
            | Self::Choice { name, .. }
            | Self::Json { name, .. } => *name = label,
        }
        self
    }

    /// Sets the output description; a no-op for [`Output::Text`].
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        let text = Some(description.into());
        match &mut self {
            Self::Text => {}
            Self::Object { description, .. }
            | Self::Array { description, .. }
            | Self::Choice { description, .. }
            | Self::Json { description, .. } => *description = text,
        }
        self
    }

    /// System instruction describing the expected format, or `None` for plain text.
    fn instruction(&self) -> Option<String> {
        let instruction = match self {
            Self::Text => return None,
            Self::Object {
                name,
                description,
                schema,
            } => format!(
                "Return ONLY valid JSON for an object output.{}{} Schema: {}",
                output_name_instruction(name),
                output_description_instruction(description),
                tool_schema_json(schema),
            ),
            Self::Array {
                name,
                description,
                element,
            } => format!(
                "Return ONLY valid JSON for an array output.{}{} Each element must match: {}",
                output_name_instruction(name),
                output_description_instruction(description),
                tool_schema_json(element),
            ),
            Self::Choice {
                name,
                description,
                options,
            } => format!(
                "Return ONLY one of these exact strings: {}.{}{}",
                options.join(", "),
                output_name_instruction(name),
                output_description_instruction(description),
            ),
            Self::Json { name, description } => format!(
                "Return ONLY valid JSON.{}{}",
                output_name_instruction(name),
                output_description_instruction(description),
            ),
        };
        Some(instruction)
    }
}
/// Runs generation against an already-resolved model, choosing the output format from `T`.
async fn generate_text_with_model<T: DeserializeOwned + JsonSchema + 'static>(
    model: &(impl LanguageModel + ?Sized),
    request: GenerateTextRequest<T>,
) -> Result<GenerateTextResult<T>> {
    generate_text_with_output::<T>(model, request, output_for::<T>()).await
}
/// Runs the generate → execute tools → generate loop against an already-resolved model.
///
/// Each step sends `model_request` to the model and adds its usage to the running total.
/// The loop returns when a response contains no tool calls (keeping that response's finish
/// reason), or when the step budget is used up; pending tool calls are then returned
/// unexecuted with [`FinishReason::Length`]. Otherwise the assistant turn plus one
/// tool-result message per call is appended (unknown tools yield an error result
/// `"tool not found"`). A `ToolChoice::Required` choice is also relaxed to `Auto`, so the
/// model can answer in the next step.
///
/// `max_steps == 0` (the request default) is treated as a single step.
///
/// # Errors
/// Propagates model errors, and returns [`Error::Parse`] when the final text does not
/// match `output`.
async fn generate_text_with_output<T: DeserializeOwned>(
    model: &(impl LanguageModel + ?Sized),
    request: GenerateTextRequest<T>,
    output: Output,
) -> Result<GenerateTextResult<T>> {
    let max_steps = request.max_steps.max(1);
    let (mut model_request, tools) = request.into_parts(&output);
    let mut total_usage = Usage::default();
    for step in 1..=max_steps {
        let response = model.generate(&model_request).await?;
        total_usage.input_tokens += response.usage.input_tokens;
        total_usage.output_tokens += response.usage.output_tokens;
        let tool_calls = response.tool_calls();
        if tool_calls.is_empty() || step == max_steps {
            let text = response.text();
            let finish_reason = if tool_calls.is_empty() {
                response.finish_reason
            } else {
                // Out of steps while the model still wants to call tools.
                FinishReason::Length
            };
            return Ok(GenerateTextResult {
                output: parse_output::<T>(&output, &text)?,
                text,
                tool_calls,
                finish_reason,
                usage: total_usage,
                response_metadata: response.response_metadata,
                parts: response.parts,
            });
        }
        // The calls were extracted above, so the parts can be moved instead of cloned.
        model_request
            .messages
            .push(ModelMessage::assistant_parts(response.parts));
        for call in tool_calls {
            let result = match tools.iter().find(|tool| tool.definition.name == call.name) {
                Some(tool) => tool.execute(&call).await,
                None => ToolResult {
                    call_id: call.id,
                    name: call.name,
                    output: "tool not found".to_string(),
                    is_error: true,
                },
            };
            model_request
                .messages
                .push(ModelMessage::tool_result(result));
        }
        if matches!(model_request.tool_choice, ToolChoice::Required(_)) {
            model_request.tool_choice = ToolChoice::Auto;
        }
    }
    // `max_steps >= 1` and the last iteration always returns above.
    unreachable!("generate_text loop must return before exhaustion")
}
pub async fn generate_text<T: DeserializeOwned + JsonSchema + 'static>(
request: GenerateTextRequest<T>,
) -> Result<GenerateTextResult<T>> {
let model = resolve_language_model(&request.model)?;
generate_text_with_model(model.as_ref(), request).await
}
/// Resolves `request.model` and returns the model's event stream.
///
/// Tool definitions are sent to the model, but tools are not executed here; tool calls
/// only surface as [`StreamEvent::ToolCall`] events.
pub async fn stream_text<T: JsonSchema + 'static>(
    request: GenerateTextRequest<T>,
) -> Result<StreamTextStream> {
    let model = resolve_language_model(&request.model)?;
    let format = output_for::<T>();
    let (model_request, _) = request.into_parts(&format);
    model.stream(&model_request).await
}
/// Embedding of a single value plus the usage reported for it.
#[derive(Clone, Debug, PartialEq)]
pub struct EmbeddingResult {
pub embedding: Vec<f32>,
pub usage: Usage,
}
/// Embeddings for several values, in input order, plus their combined usage.
#[derive(Clone, Debug, PartialEq)]
pub struct BatchEmbeddingResult {
pub embeddings: Vec<Vec<f32>>,
pub usage: Usage,
}
/// Embeds `value` with the embedding model named by the `provider/model` id `model`.
pub fn embed(model: &str, value: &str) -> Result<EmbeddingResult> {
    let embedding_model = resolve_embedding_model(model)?;
    embedding_model.embed(value)
}
/// Embeds every value, in order, with the model named by the `provider/model` id `model`.
pub fn embed_many(model: &str, values: &[impl AsRef<str>]) -> Result<BatchEmbeddingResult> {
    let owned: Vec<String> = values
        .iter()
        .map(|value| value.as_ref().to_owned())
        .collect();
    resolve_embedding_model(model)?.embed_many(&owned)
}
/// Cosine similarity of two vectors; returns `0.0` when either vector has zero norm.
///
/// The dot product covers only the overlapping prefix when lengths differ, while each
/// norm uses its whole vector.
pub fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
    let norm = |values: &[f32]| values.iter().map(|v| v * v).sum::<f32>().sqrt();
    let (left_norm, right_norm) = (norm(left), norm(right));
    if left_norm == 0.0 || right_norm == 0.0 {
        return 0.0;
    }
    let dot: f32 = left.iter().zip(right).map(|(a, b)| a * b).sum();
    dot / (left_norm * right_norm)
}
/// Response metadata holding the `provider` and `model_id` keys.
pub(crate) fn metadata_with_provider(provider: &str, model_id: &str) -> BTreeMap<String, String> {
    [("provider", provider), ("model_id", model_id)]
        .iter()
        .map(|&(key, value)| (key.to_string(), value.to_string()))
        .collect()
}
/// Rough usage estimate: whitespace-separated words in the messages' text parts (input)
/// and in `output` (output).
pub(crate) fn estimate_usage(messages: &[ModelMessage], output: &str) -> Usage {
    let input_tokens: usize = messages
        .iter()
        .map(|message| count_tokens(&message.text()))
        .sum();
    Usage {
        input_tokens,
        output_tokens: count_tokens(output),
    }
}
/// Rough token count: the number of Unicode-whitespace-separated words in `value`.
pub(crate) fn count_tokens(value: &str) -> usize {
    value
        .split(char::is_whitespace)
        .filter(|word| !word.is_empty())
        .count()
}
/// Scores each `(label, embedding)` against `query` with cosine similarity and returns
/// them highest score first; ties keep their input order (stable sort).
pub fn rank_by_similarity<'a>(
    query: &[f32],
    values: impl IntoIterator<Item = (&'a str, &'a [f32])>,
) -> Vec<(&'a str, f32)> {
    let mut ranked: Vec<(&'a str, f32)> = Vec::new();
    for (label, embedding) in values {
        ranked.push((label, cosine_similarity(query, embedding)));
    }
    // `total_cmp` gives a total order, so NaN scores cannot make the comparison inconsistent.
    ranked.sort_by(|a, b| b.1.total_cmp(&a.1));
    ranked
}
pub fn provider_registry() -> Vec<&'static str> {
registered_providers()
}
/// Instruction fragment `" Name: <name>."`, or an empty string when there is no name.
fn output_name_instruction(name: &Option<String>) -> String {
    match name {
        Some(name) => format!(" Name: {name}."),
        None => String::new(),
    }
}
/// Instruction fragment `" Description: <text>."`, or an empty string when absent.
fn output_description_instruction(description: &Option<String>) -> String {
    match description {
        Some(text) => format!(" Description: {text}."),
        None => String::new(),
    }
}
fn parse_output<T: DeserializeOwned>(output: &Output, text: &str) -> Result<T> {
match output {
Output::Text => serde_json::from_value(Value::String(text.to_string()))
.map_err(|error| Error::Parse(error.to_string())),
Output::Json { .. } => serde_json::from_value(parse_json(text)?)
.map_err(|error| Error::Parse(error.to_string())),
Output::Object { schema, .. } => {
let value = parse_json(text)?;
validate_against_schema(&value, schema)?;
serde_json::from_value(value).map_err(|error| Error::Parse(error.to_string()))
}
Output::Array { element, .. } => {
let value = parse_json(text)?;
let items = value
.as_array()
.ok_or_else(|| Error::Parse("expected a JSON array".to_string()))?;
for item in items {
validate_against_schema(item, element)?;
}
serde_json::from_value(value).map_err(|error| Error::Parse(error.to_string()))
}
Output::Choice { options, .. } => {
let choice = text.trim().trim_matches('"').to_string();
if options.iter().any(|option| option == &choice) {
serde_json::from_value(Value::String(choice))
.map_err(|error| Error::Parse(error.to_string()))
} else {
Err(Error::Parse(format!("invalid choice output: {choice}")))
}
}
}
}
/// Default output format for `T`: text for `String`, free JSON for `serde_json::Value`,
/// otherwise an object whose schema is derived from `T`.
fn output_for<T: JsonSchema + 'static>() -> Output {
    let requested = TypeId::of::<T>();
    if requested == TypeId::of::<String>() {
        return Output::text();
    }
    if requested == TypeId::of::<Value>() {
        return Output::json();
    }
    Output::object::<T>()
}
fn parse_json(text: &str) -> Result<Value> {
serde_json::from_str(text).map_err(|error| Error::Parse(error.to_string()))
}
fn validate_against_schema(value: &Value, schema: &ToolSchema) -> Result<()> {
match schema {
ToolSchema::String { .. } => {
if value.is_string() {
Ok(())
} else {
Err(Error::Parse("expected string".to_string()))
}
}
ToolSchema::Integer { .. } => {
if value.as_i64().is_some() || value.as_u64().is_some() {
Ok(())
} else {
Err(Error::Parse("expected integer".to_string()))
}
}
ToolSchema::Number { .. } => {
if value.as_f64().is_some() || value.as_i64().is_some() || value.as_u64().is_some() {
Ok(())
} else {
Err(Error::Parse("expected number".to_string()))
}
}
ToolSchema::Boolean { .. } => {
if value.is_boolean() {
Ok(())
} else {
Err(Error::Parse("expected boolean".to_string()))
}
}
ToolSchema::Array { items, .. } => {
let array = value
.as_array()
.ok_or_else(|| Error::Parse("expected array".to_string()))?;
for item in array {
validate_against_schema(item, items)?;
}
Ok(())
}
ToolSchema::Object(object) => {
let map = value
.as_object()
.ok_or_else(|| Error::Parse("expected object".to_string()))?;
for field in &object.fields {
match map.get(&field.name) {
Some(field_value) => validate_against_schema(field_value, &field.schema)?,
None if field.required => {
return Err(Error::Parse(format!(
"missing required field: {}",
field.name
)));
}
None => {}
}
}
Ok(())
}
}
}
/// Derives a [`ToolSchema`] from `T`'s schemars root schema.
fn tool_schema_for<T: JsonSchema>() -> ToolSchema {
    tool_schema_from_root(&schemars::schema_for!(T))
}
fn tool_schema_from_root(root: &RootSchema) -> ToolSchema {
tool_schema_from_schema_object(&root.schema, root)
}
/// Converts any schemars [`Schema`]; boolean schemas (`true`/`false`) become a plain string.
fn tool_schema_from_schema(schema: &Schema, root: &RootSchema) -> ToolSchema {
    if let Schema::Object(object) = schema {
        return tool_schema_from_schema_object(object, root);
    }
    ToolSchema::string()
}
/// Converts one schemars [`SchemaObject`] into the crate's [`ToolSchema`].
///
/// Resolution order:
/// 1. a `#/definitions/<name>` reference present in `root.definitions` is followed;
/// 2. otherwise the first `anyOf` / `oneOf` / `allOf` subschema (see `pick_subschema`) wins;
/// 3. otherwise the first non-null `instance_type` picks the variant;
/// 4. a schema with no `instance_type` becomes an object with no declared fields.
///
/// NOTE(review): references are followed without tracking which definitions are already
/// being expanded, so a self-referential type presumably recurses without bound — confirm
/// before using recursive types as tool inputs or structured outputs.
fn tool_schema_from_schema_object(object: &SchemaObject, root: &RootSchema) -> ToolSchema {
    if let Some(reference) = &object.reference {
        if let Some(definition) = reference.strip_prefix("#/definitions/") {
            if let Some(schema) = root.definitions.get(definition) {
                return tool_schema_from_schema(schema, root);
            }
        }
    }
    if let Some(subschemas) = &object.subschemas {
        if let Some(schema) = pick_subschema(subschemas, root) {
            return schema;
        }
    }
    let description = object
        .metadata
        .as_ref()
        .and_then(|metadata| metadata.description.clone());
    let Some(instance_type) = &object.instance_type else {
        return ToolSchema::Object(ToolObjectSchema {
            description,
            fields: Vec::new(),
        });
    };
    match pick_instance_type(instance_type) {
        Some(InstanceType::String) => ToolSchema::String { description },
        Some(InstanceType::Integer) => ToolSchema::Integer { description },
        Some(InstanceType::Number) => ToolSchema::Number { description },
        Some(InstanceType::Boolean) => ToolSchema::Boolean { description },
        Some(InstanceType::Array) => {
            // Tuple-style `items` lists are approximated by their first entry; a missing
            // `items` falls back to strings.
            let items = object
                .array
                .as_ref()
                .and_then(|array| array.items.as_ref())
                .map(|items| match items {
                    SingleOrVec::Single(schema) => tool_schema_from_schema(schema, root),
                    SingleOrVec::Vec(items) => items
                        .first()
                        .map(|schema| tool_schema_from_schema(schema, root))
                        .unwrap_or_else(ToolSchema::string),
                })
                .unwrap_or_else(ToolSchema::string);
            ToolSchema::Array {
                description,
                items: Box::new(items),
            }
        }
        Some(InstanceType::Object) => {
            let validation = object.object.as_ref();
            let required = validation
                .map(|object| object.required.clone())
                .unwrap_or_default();
            let fields = validation
                .map(|object| {
                    object
                        .properties
                        .iter()
                        .map(|(name, schema)| ToolFieldSchema {
                            name: name.clone(),
                            description: schema_description(schema),
                            schema: tool_schema_from_schema(schema, root),
                            required: required.contains(name),
                        })
                        .collect::<Vec<_>>()
                })
                .unwrap_or_default();
            ToolSchema::Object(ToolObjectSchema {
                description,
                fields,
            })
        }
        // `null` (alone, or every listed type being null) degrades to a string. The
        // description is passed through unchanged: the previous
        // `with_description(description.unwrap_or_default())` turned a missing description
        // into `Some("")`, which serialized as `"description": ""` in the model-facing schema.
        _ => ToolSchema::String { description },
    }
}
/// The `description` metadata of `schema`, if it is an object schema that has one.
fn schema_description(schema: &Schema) -> Option<String> {
    let Schema::Object(object) = schema else {
        return None;
    };
    object.metadata.as_ref()?.description.clone()
}
/// A single type is returned as-is (even `Null`). For a list, the first non-null type is
/// returned, so `["string", "null"]` (an optional string) yields `String`.
fn pick_instance_type(instance_type: &SingleOrVec<InstanceType>) -> Option<InstanceType> {
    match instance_type {
        SingleOrVec::Single(only) => Some(**only),
        SingleOrVec::Vec(candidates) => candidates
            .iter()
            .copied()
            .find(|candidate| *candidate != InstanceType::Null),
    }
}
fn pick_subschema(subschemas: &SubschemaValidation, root: &RootSchema) -> Option<ToolSchema> {
subschemas
.any_of
.as_ref()
.and_then(|schemas| schemas.first())
.or_else(|| {
subschemas
.one_of
.as_ref()
.and_then(|schemas| schemas.first())
})
.or_else(|| {
subschemas
.all_of
.as_ref()
.and_then(|schemas| schemas.first())
})
.map(|schema| tool_schema_from_schema(schema, root))
}
/// Renders a [`ToolSchema`] as JSON Schema.
///
/// Objects list every field under `properties`, the required ones under `required`, and
/// set `additionalProperties: false`. A field's own description overrides the description
/// of its schema.
fn tool_schema_json(schema: &ToolSchema) -> Value {
    let (body, description) = match schema {
        ToolSchema::String { description } => (json!({ "type": "string" }), description),
        ToolSchema::Integer { description } => (json!({ "type": "integer" }), description),
        ToolSchema::Number { description } => (json!({ "type": "number" }), description),
        ToolSchema::Boolean { description } => (json!({ "type": "boolean" }), description),
        ToolSchema::Array { description, items } => (
            json!({ "type": "array", "items": tool_schema_json(items) }),
            description,
        ),
        ToolSchema::Object(object) => {
            let mut properties = serde_json::Map::new();
            for field in &object.fields {
                let mut property = tool_schema_json(&field.schema);
                if let Some(text) = &field.description {
                    property["description"] = json!(text);
                }
                properties.insert(field.name.clone(), property);
            }
            let required: Vec<Value> = object
                .fields
                .iter()
                .filter(|field| field.required)
                .map(|field| Value::String(field.name.clone()))
                .collect();
            (
                json!({
                    "type": "object",
                    "properties": properties,
                    "required": required,
                    "additionalProperties": false,
                }),
                &object.description,
            )
        }
    };
    json_with_description(body, description)
}
/// Sets `value["description"]` when a description is present; `value` must be a JSON object.
fn json_with_description(mut value: Value, description: &Option<String>) -> Value {
    match description {
        Some(text) => {
            value["description"] = json!(text);
            value
        }
        None => value,
    }
}
// Integration-style tests against the `mock` provider. The exact texts asserted below
// depend on that provider's canned behavior.
// NOTE(review): the mock lives in the `providers` module — confirm its responses there.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Mutex};
// A bare `prompt` must become a single user message; `T = String` returns the text verbatim.
#[tokio::test]
async fn generate_text_normalizes_a_prompt_into_messages() {
let result = generate_text::<String>(GenerateTextRequest {
model: "mock/mock-1".to_string(),
prompt: Some("hello llm".to_string()),
..Default::default()
})
.await
.unwrap();
assert_eq!(result.text, "Mock response from mock-1: hello llm");
assert_eq!(result.output, "Mock response from mock-1: hello llm");
assert_eq!(result.finish_reason, FinishReason::Stop);
}
// Streaming does not execute tools: the tool call should surface as an event, and the
// stream always ends with `Finish`.
#[tokio::test]
async fn stream_text_surfaces_tool_calls() {
let mut events = stream_text::<String>(GenerateTextRequest {
model: "mock/mock-1".to_string(),
prompt: Some("remember rust uses cargo".to_string()),
tools: vec![Tool::new(
"add_resource",
"save knowledge",
ToolSchema::string(),
|input| format!("stored: {input}"),
)],
..Default::default()
})
.await
.unwrap();
let events = futures_util::StreamExt::collect::<Vec<_>>(&mut events).await;
assert!(matches!(events.first(), Some(StreamEvent::ToolCall(_))));
assert!(matches!(events.last(), Some(StreamEvent::Finish { .. })));
}
// With two steps, the first response's tool call is executed exactly once and the
// second response incorporates the tool output.
#[tokio::test]
async fn generate_text_executes_tool_calls_until_a_final_answer_is_available() {
let store = Arc::new(Mutex::new(Vec::<String>::new()));
let store_for_add = Arc::clone(&store);
let add_resource = Tool::new(
"add_resource",
"save knowledge",
ToolSchema::string(),
move |input| {
store_for_add.lock().unwrap().push(input.to_string());
format!("stored: {input}")
},
);
let result = generate_text::<String>(GenerateTextRequest {
model: "mock/mock-1".to_string(),
prompt: Some("remember Vercel AI SDK inspired this PoC".to_string()),
tools: vec![add_resource],
max_steps: 2,
..Default::default()
})
.await
.unwrap();
assert_eq!(result.finish_reason, FinishReason::Stop);
assert!(
result
.text
.contains("stored: Vercel AI SDK inspired this PoC")
);
assert_eq!(store.lock().unwrap().len(), 1);
}
// A non-String, non-Value `T` selects object output; the mock's JSON reply is parsed into `Status`.
#[tokio::test]
async fn generate_text_parses_structured_output_via_output() {
#[derive(serde::Deserialize, schemars::JsonSchema)]
struct Status {
topic: String,
status: String,
}
let result = generate_text::<Status>(GenerateTextRequest {
model: "mock/mock-1".to_string(),
prompt: Some("return json".to_string()),
..Default::default()
})
.await
.unwrap();
assert_eq!(result.output.topic, "llm");
assert_eq!(result.output.status, "ok");
}
// The mock embedding model should rank "cargo workspace" closest to "rust cargo".
#[test]
fn embeddings_support_similarity_ranking() {
let query = embed("mock/embed-1", "rust cargo").unwrap();
let docs = vec![
(
"rust crates",
embed("mock/embed-1", "rust crates").unwrap().embedding,
),
(
"garden flowers",
embed("mock/embed-1", "garden flowers").unwrap().embedding,
),
(
"cargo workspace",
embed("mock/embed-1", "cargo workspace").unwrap().embedding,
),
];
let ranked = rank_by_similarity(
&query.embedding,
docs.iter()
.map(|(name, embedding)| (*name, embedding.as_slice())),
);
assert_eq!(ranked[0].0, "cargo workspace");
assert!(ranked[0].1 >= ranked[1].1);
}
// `#[derive(Tool)]` supplies the `ToolInput` metadata; this checks the derived name and
// description and runs the tool end to end with JSON input and output.
// NOTE(review): the expected name "weather" reflects how `menta_derive` names
// `WeatherTool` — confirm the naming rule there.
#[tokio::test]
async fn derive_tool_generates_schema_and_executes() {
#[derive(serde::Deserialize, Tool)]
#[tool(description = "Get the weather in a location")]
struct WeatherTool {
#[description = "The location to get the weather for"]
location: String,
}
#[derive(Debug, PartialEq, serde::Serialize)]
struct WeatherOutput {
temperature: i32,
conditions: String,
}
impl ToolExecute for WeatherTool {
type Output = WeatherOutput;
async fn execute(&self) -> std::result::Result<Self::Output, String> {
Ok(WeatherOutput {
temperature: 72,
conditions: format!("sunny in {}", self.location),
})
}
}
let definition = WeatherTool::definition();
assert_eq!(definition.name, "weather");
assert_eq!(definition.description, "Get the weather in a location");
assert!(matches!(definition.input_schema, ToolSchema::Object(_)));
let tool = WeatherTool::tool();
let call = ToolCall::new("weather", r#"{"location":"Paris"}"#);
let result = tool.execute(&call).await;
assert!(!result.is_error);
assert_eq!(
result.output,
r#"{"temperature":72,"conditions":"sunny in Paris"}"#
);
}
}