use crate::json_utils::merge_inplace;
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
embeddings::{self, EmbeddingError, EmbeddingsBuilder},
extractor::ExtractorBuilder,
json_utils, message,
message::{ImageDetail, Text},
Embed, OneOrMany,
};
use async_stream::stream;
use futures::StreamExt;
use reqwest;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::convert::Infallible;
use std::{convert::TryFrom, str::FromStr};
/// Base URL that [`Client::new`] passes to [`Client::from_url`] when the caller
/// does not supply one.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";
/// Handle used to send requests to an Ollama server.
///
/// Holds the server base URL together with the `reqwest::Client` used for
/// every request built through this handle.
#[derive(Clone)]
pub struct Client {
// Prefix prepended to every request path.
base_url: String,
// Underlying HTTP client used to build requests.
http_client: reqwest::Client,
}
impl Default for Client {
fn default() -> Self {
Self::new()
}
}
impl Client {
    /// Creates a client that talks to the default base URL
    /// (`OLLAMA_API_BASE_URL`).
    pub fn new() -> Self {
        Self::from_url(OLLAMA_API_BASE_URL)
    }

    /// Creates a client that talks to the server at `base_url`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` client cannot be built.
    pub fn from_url(base_url: &str) -> Self {
        Self {
            base_url: base_url.to_owned(),
            http_client: reqwest::Client::builder()
                .build()
                .expect("Ollama reqwest client should build"),
        }
    }

    /// Starts a POST request to `path`, resolved against the base URL.
    ///
    /// Trailing slashes on the base URL and leading slashes on `path` are
    /// stripped before joining, so `http://host/` + `api/chat` and
    /// `http://host` + `/api/chat` both resolve to `http://host/api/chat`
    /// instead of producing a double slash.
    fn post(&self, path: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );
        self.http_client.post(url)
    }

    /// Creates an embedding model handle for `model`.
    ///
    /// The reported number of dimensions is `0`; use
    /// [`Client::embedding_model_with_ndims`] to set it explicitly.
    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, 0)
    }

    /// Creates an embedding model handle for `model` that reports `ndims`
    /// dimensions.
    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Creates an [`EmbeddingsBuilder`] backed by the embedding model `model`.
    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
        EmbeddingsBuilder::new(self.embedding_model(model))
    }

    /// Creates a completion model handle for `model`.
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Creates an [`AgentBuilder`] backed by the completion model `model`.
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Creates an [`ExtractorBuilder`] for `T` backed by the completion model
    /// `model`.
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }
}
/// Error payload carrying a human-readable `message`.
///
/// Converted into `EmbeddingError::ProviderError` by the `From` impl in this
/// file.
// NOTE(review): nothing in this file deserializes this type directly, and
// Ollama's documented error body appears to use an `error` key rather than
// `message` — confirm the field name against the server before relying on it.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Deserialized as an untagged enum, so serde tries the variants in
/// declaration order: `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
// Successful payload.
Ok(T),
// Error payload.
Err(ApiErrorResponse),
}
/// Model name string `all-minilm`; presumably intended for use with
/// [`Client::embedding_model`].
pub const ALL_MINILM: &str = "all-minilm";
/// Model name string `nomic-embed-text`; presumably intended for use with
/// [`Client::embedding_model`].
pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text";
/// Body of a successful response from the `api/embed` endpoint, as
/// deserialized by `embed_texts`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingResponse {
/// Name of the model reported in the response.
pub model: String,
/// One embedding vector per input document; `embed_texts` pairs them with
/// the inputs positionally.
pub embeddings: Vec<Vec<f64>>,
/// Optional timing figure; `None` when absent from the response.
#[serde(default)]
pub total_duration: Option<u64>,
/// Optional timing figure; `None` when absent from the response.
#[serde(default)]
pub load_duration: Option<u64>,
/// Optional counter; `None` when absent from the response.
#[serde(default)]
pub prompt_eval_count: Option<u64>,
}
impl From<ApiErrorResponse> for EmbeddingError {
fn from(err: ApiErrorResponse) -> Self {
EmbeddingError::ProviderError(err.message)
}
}
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
match value {
ApiResponse::Ok(response) => Ok(response),
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
}
}
}
/// Handle for one embedding model served through a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
// Client used to send embedding requests.
client: Client,
/// Model name sent in the `model` field of each request.
pub model: String,
// Value returned by `ndims()`.
ndims: usize,
}
impl EmbeddingModel {
pub fn new(client: Client, model: &str, ndims: usize) -> Self {
Self {
client,
model: model.to_owned(),
ndims,
}
}
}
impl embeddings::EmbeddingModel for EmbeddingModel {
    const MAX_DOCUMENTS: usize = 1024;

    /// Returns the dimension count supplied when the model was constructed.
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds `documents` with a single POST to `api/embed`.
    ///
    /// # Errors
    ///
    /// * `EmbeddingError::ProviderError` when sending fails, when the body
    ///   cannot be decoded, or when the server answers with a non-success
    ///   status (the response body is used as the message).
    /// * `EmbeddingError::ResponseError` when the number of returned vectors
    ///   differs from the number of input documents.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let inputs: Vec<String> = documents.into_iter().collect();
        let body = json!({
            "model": self.model,
            "input": inputs,
        });

        let http_response = self
            .client
            .post("api/embed")
            .json(&body)
            .send()
            .await
            .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;

        // Non-success statuses carry the provider's message in the body.
        if !http_response.status().is_success() {
            return Err(EmbeddingError::ProviderError(http_response.text().await?));
        }

        let parsed: EmbeddingResponse = http_response
            .json()
            .await
            .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;

        // Vectors are matched to documents by position, so the counts must agree.
        if parsed.embeddings.len() != inputs.len() {
            return Err(EmbeddingError::ResponseError(
                "Number of returned embeddings does not match input".into(),
            ));
        }

        let paired = inputs
            .into_iter()
            .zip(parsed.embeddings)
            .map(|(document, vec)| embeddings::Embedding { document, vec })
            .collect();
        Ok(paired)
    }
}
/// Model name string `llama3.2`; presumably intended for use with
/// [`Client::completion_model`].
pub const LLAMA3_2: &str = "llama3.2";
/// Model name string `llava`; presumably intended for use with
/// [`Client::completion_model`].
pub const LLAVA: &str = "llava";
/// Model name string `mistral`; presumably intended for use with
/// [`Client::completion_model`].
pub const MISTRAL: &str = "mistral";
/// One response object from the `api/chat` endpoint.
///
/// Used both for the single non-streaming response body and for each line of
/// a streaming response. All `Option` fields default to `None` when missing.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
/// Name of the model reported in the response.
pub model: String,
/// Creation timestamp, kept as the raw string from the response.
pub created_at: String,
/// The message carried by this response.
pub message: Message,
/// Completion flag reported by the server.
pub done: bool,
/// Optional reason string accompanying `done`.
#[serde(default)]
pub done_reason: Option<String>,
/// Optional timing figure reported by the server.
#[serde(default)]
pub total_duration: Option<u64>,
/// Optional timing figure reported by the server.
#[serde(default)]
pub load_duration: Option<u64>,
/// Optional counter reported by the server.
#[serde(default)]
pub prompt_eval_count: Option<u64>,
/// Optional timing figure reported by the server.
#[serde(default)]
pub prompt_eval_duration: Option<u64>,
/// Optional counter reported by the server.
#[serde(default)]
pub eval_count: Option<u64>,
/// Optional timing figure reported by the server.
#[serde(default)]
pub eval_duration: Option<u64>,
}
/// Converts a provider chat response into the generic completion response.
///
/// The assistant message's non-empty text and each of its tool calls become
/// the entries of `choice`; the provider response itself is kept in
/// `raw_response`.
///
/// # Errors
///
/// Returns `CompletionError::ResponseError` when the message is not an
/// assistant message, or when it has neither text nor tool calls.
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(resp: CompletionResponse) -> Result<Self, Self::Error> {
        let Message::Assistant {
            content,
            images,
            name,
            tool_calls,
        } = resp.message
        else {
            return Err(CompletionError::ResponseError(
                "Chat response does not include an assistant message".into(),
            ));
        };

        let mut assistant_contents = Vec::with_capacity(tool_calls.len() + 1);
        if !content.is_empty() {
            assistant_contents.push(completion::AssistantContent::text(&content));
        }
        // `ToolCall` carries no separate identifier, so the function name is
        // passed for both leading string arguments of `tool_call`.
        assistant_contents.extend(tool_calls.iter().map(|tc| {
            completion::AssistantContent::tool_call(
                tc.function.name.clone(),
                tc.function.name.clone(),
                tc.function.arguments.clone(),
            )
        }));

        // `OneOrMany::many` rejects an empty list: no text and no tool calls.
        let choice = OneOrMany::many(assistant_contents)
            .map_err(|_| CompletionError::ResponseError("No content provided".to_owned()))?;

        // Rebuild the raw response with the message exactly as received;
        // `images` and `name` are preserved rather than reset to `None`.
        let raw_response = CompletionResponse {
            model: resp.model,
            created_at: resp.created_at,
            done: resp.done,
            done_reason: resp.done_reason,
            total_duration: resp.total_duration,
            load_duration: resp.load_duration,
            prompt_eval_count: resp.prompt_eval_count,
            prompt_eval_duration: resp.prompt_eval_duration,
            eval_count: resp.eval_count,
            eval_duration: resp.eval_duration,
            message: Message::Assistant {
                content,
                images,
                name,
                tool_calls,
            },
        };

        Ok(completion::CompletionResponse {
            choice,
            raw_response,
        })
    }
}
/// Handle for one chat completion model served through a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
// Client used to send chat requests.
client: Client,
/// Model name sent in the `model` field of each request.
pub model: String,
}
impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_owned(),
}
}
fn create_completion_request(
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
let prompt: Message = completion_request.prompt_with_context().try_into()?;
let options = if let Some(extra) = completion_request.additional_params {
json_utils::merge(
json!({ "temperature": completion_request.temperature }),
extra,
)
} else {
json!({ "temperature": completion_request.temperature })
};
let mut full_history = Vec::new();
if let Some(preamble) = completion_request.preamble {
full_history.push(Message::system(&preamble));
}
for msg in completion_request.chat_history.into_iter() {
full_history.push(Message::try_from(msg)?);
}
full_history.push(prompt);
let mut request_payload = json!({
"model": self.model,
"messages": full_history,
"options": options,
"stream": false,
});
if !completion_request.tools.is_empty() {
request_payload["tools"] = json!(completion_request
.tools
.into_iter()
.map(|tool| tool.into())
.collect::<Vec<ToolDefinition>>());
}
tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
Ok(request_payload)
}
}
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;

    /// Sends a non-streaming chat request to `api/chat` and converts the reply.
    ///
    /// # Errors
    ///
    /// Returns `CompletionError::ProviderError` when sending fails, when the
    /// body cannot be read or parsed, or when the server answers with a
    /// non-success status (the body is used as the message). Errors from
    /// building the request or converting the response are passed through.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        let payload = self.create_completion_request(completion_request)?;

        let http_response = self
            .client
            .post("api/chat")
            .json(&payload)
            .send()
            .await
            .map_err(|err| CompletionError::ProviderError(err.to_string()))?;

        // Capture the status before `text()` consumes the response.
        let succeeded = http_response.status().is_success();
        let body = http_response
            .text()
            .await
            .map_err(|err| CompletionError::ProviderError(err.to_string()))?;

        if !succeeded {
            return Err(CompletionError::ProviderError(body));
        }

        tracing::debug!(target: "rig", "Ollama chat response: {}", body);
        let parsed: CompletionResponse = serde_json::from_str(&body)
            .map_err(|err| CompletionError::ProviderError(err.to_string()))?;
        parsed.try_into()
    }
}
impl StreamingCompletionModel for CompletionModel {
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
let mut request_payload = self.create_completion_request(request)?;
merge_inplace(&mut request_payload, json!({"stream": true}));
let response = self
.client
.post("api/chat")
.json(&request_payload)
.send()
.await
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let err_text = response
.text()
.await
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
return Err(CompletionError::ProviderError(err_text));
}
Ok(Box::pin(stream! {
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
Ok(c) => c,
Err(e) => {
yield Err(CompletionError::from(e));
break;
}
};
let text = match String::from_utf8(chunk.to_vec()) {
Ok(t) => t,
Err(e) => {
yield Err(CompletionError::ResponseError(e.to_string()));
break;
}
};
for line in text.lines() {
let line = line.to_string();
let Ok(response) = serde_json::from_str::<CompletionResponse>(&line) else {
continue;
};
match response.message {
Message::Assistant{ content, tool_calls, .. } => {
if !content.is_empty() {
yield Ok(StreamingChoice::Message(content))
}
for tool_call in tool_calls.iter() {
let function = tool_call.function.clone();
yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
}
}
_ => {
continue;
}
}
}
}
}))
}
}
/// Tool description in the shape placed in the `tools` array of a chat
/// request: a `type` tag plus the wrapped function definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
// `type_field` is serialized under the key `type`; the `From` impl in this
// file always sets it to "function". `function` is the wrapped definition.
#[serde(rename = "type")]
pub type_field: String, pub function: completion::ToolDefinition,
}
/// Wraps a generic tool definition, tagging it with the type `"function"`.
impl From<crate::completion::ToolDefinition> for ToolDefinition {
    fn from(tool: crate::completion::ToolDefinition) -> Self {
        let crate::completion::ToolDefinition {
            name,
            description,
            parameters,
        } = tool;
        Self {
            type_field: String::from("function"),
            function: completion::ToolDefinition {
                name,
                description,
                parameters,
            },
        }
    }
}
/// A tool invocation attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
/// Kind of tool call, serialized under the key `type`; falls back to the
/// default ([`ToolType::Function`]) when missing.
#[serde(default, rename = "type")]
pub r#type: ToolType,
/// The function being called and its arguments.
pub function: Function,
}
/// Kind of a [`ToolCall`]; variant names are serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
/// A function call (serialized as `function`); this is the default.
#[default]
Function,
}
/// Function name and arguments carried by a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
/// Name of the function to call.
pub name: String,
/// Arguments as an arbitrary JSON value.
pub arguments: Value,
}
/// A chat message in the shape exchanged with the `api/chat` endpoint.
///
/// Serialized as an internally tagged enum: the variant is carried in the
/// `role` field using lowercase role names (`user`, `assistant`, `system`,
/// `tool`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A message authored by the user.
    User {
        /// Text of the message.
        content: String,
        /// Optional image payloads; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// Optional name; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A message authored by the model.
    Assistant {
        /// Text of the message; defaults to the empty string when missing.
        #[serde(default)]
        content: String,
        /// Optional image payloads; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// Optional name; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Tool calls requested by the model; a missing field yields an empty
        /// list, and `json_utils::null_or_vec` handles the deserialization.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// A system instruction.
    System {
        /// Text of the message.
        content: String,
        /// Optional image payloads; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// Optional name; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// The result of a tool call.
    ///
    /// Serialized with the lowercase role `tool`, consistent with the other
    /// variants. The previously emitted capitalised `Tool` is still accepted
    /// when deserializing.
    #[serde(rename = "tool", alias = "Tool")]
    ToolResult {
        /// Identifier of the tool call this result answers.
        tool_call_id: String,
        /// Content of the result.
        content: OneOrMany<ToolResultContent>,
    },
}
/// Converts a generic message into the provider's message shape.
///
/// Text parts are joined with a single space. For user messages, image data
/// is collected into `images` and every other kind of user content is
/// skipped. For assistant messages, tool calls are carried over with their
/// function name and arguments.
impl TryFrom<crate::message::Message> for Message {
    type Error = crate::message::MessageError;

    fn try_from(internal_msg: crate::message::Message) -> Result<Self, Self::Error> {
        use crate::message::Message as InternalMessage;

        let converted = match internal_msg {
            InternalMessage::User { content, .. } => {
                let (mut text_parts, mut image_parts) = (Vec::new(), Vec::new());
                for part in content.into_iter() {
                    match part {
                        crate::message::UserContent::Text(t) => text_parts.push(t.text),
                        crate::message::UserContent::Image(img) => image_parts.push(img.data),
                        // Other content kinds are intentionally not forwarded.
                        _ => {}
                    }
                }
                Message::User {
                    content: text_parts.join(" "),
                    // `None` rather than an empty list when there are no images.
                    images: (!image_parts.is_empty()).then_some(image_parts),
                    name: None,
                }
            }
            InternalMessage::Assistant { content, .. } => {
                let (mut text_parts, mut calls) = (Vec::new(), Vec::new());
                for part in content.into_iter() {
                    match part {
                        crate::message::AssistantContent::Text(t) => text_parts.push(t.text),
                        crate::message::AssistantContent::ToolCall(tc) => {
                            let function = Function {
                                name: tc.function.name,
                                arguments: tc.function.arguments,
                            };
                            calls.push(ToolCall {
                                r#type: ToolType::Function,
                                function,
                            });
                        }
                    }
                }
                Message::Assistant {
                    content: text_parts.join(" "),
                    images: None,
                    name: None,
                    tool_calls: calls,
                }
            }
        };
        Ok(converted)
    }
}
/// Converts a provider message back into the generic message type.
///
/// User and system messages both become generic user messages holding a
/// single text part. Assistant messages become a text part followed by one
/// tool-call part per tool call. Tool results become a user message holding
/// a tool-result part.
impl From<Message> for crate::completion::Message {
    fn from(msg: Message) -> Self {
        use crate::completion::message::{AssistantContent as GenericAssistant, UserContent};

        // User and system messages share the same single-text representation.
        let text_only = |text: String| crate::completion::Message::User {
            content: OneOrMany::one(UserContent::Text(Text { text })),
        };

        match msg {
            Message::User { content, .. } => text_only(content),
            Message::System { content, .. } => text_only(content),
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let leading_text = GenericAssistant::Text(Text { text: content });
                let parts: Vec<_> = std::iter::once(leading_text)
                    .chain(tool_calls.into_iter().map(|tc| {
                        GenericAssistant::tool_call(
                            tc.function.name.clone(),
                            tc.function.name,
                            tc.function.arguments,
                        )
                    }))
                    .collect();
                crate::completion::Message::Assistant {
                    // `parts` always contains at least the leading text entry.
                    content: OneOrMany::many(parts).unwrap(),
                }
            }
            Message::ToolResult {
                tool_call_id,
                content,
            } => {
                let results = content.map(|item| message::ToolResultContent::text(item.text));
                crate::completion::Message::User {
                    content: OneOrMany::one(message::UserContent::tool_result(
                        tool_call_id,
                        results,
                    )),
                }
            }
        }
    }
}
impl Message {
pub fn system(content: &str) -> Self {
Message::System {
content: content.to_owned(),
images: None,
name: None,
}
}
}
/// Textual content of a tool result, used inside [`Message::ToolResult`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
// The result text.
text: String,
}
impl FromStr for ToolResultContent {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(s.to_owned().into())
}
}
impl From<String> for ToolResultContent {
fn from(s: String) -> Self {
ToolResultContent { text: s }
}
}
/// A typed piece of system content: a content type tag plus its text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
// Content type tag; falls back to the default (`SystemContentType::Text`)
// when missing from the input.
#[serde(default)]
r#type: SystemContentType,
// The content text.
text: String,
}
/// Content type tag of a [`SystemContent`]; variant names are serialized in
/// lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
/// Plain text (serialized as `text`); this is the default.
#[default]
Text,
}
impl From<String> for SystemContent {
fn from(s: String) -> Self {
SystemContent {
r#type: SystemContentType::default(),
text: s,
}
}
}
impl FromStr for SystemContent {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(SystemContent {
r#type: SystemContentType::default(),
text: s.to_string(),
})
}
}
/// Textual assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AssistantContent {
/// The content text.
pub text: String,
}
impl FromStr for AssistantContent {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(AssistantContent { text: s.to_owned() })
}
}
/// A piece of user content, tagged by a lowercase `type` field
/// (`text` or `image`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
/// Plain text content.
Text { text: String },
/// An image referenced through an [`ImageUrl`].
Image { image_url: ImageUrl },
}
impl FromStr for UserContent {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(UserContent::Text { text: s.to_owned() })
}
}
/// Image reference used by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
/// Location of the image.
pub url: String,
/// Detail setting; falls back to `ImageDetail::default()` when missing.
#[serde(default)]
pub detail: ImageDetail,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A full chat response with text and one tool call deserializes and
    /// converts into a non-empty generic completion response.
    #[tokio::test]
    async fn test_chat_completion() {
        let fixture = json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": "The sky is blue because of Rayleigh scattering.",
                "images": null,
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                                "format": "celsius"
                            }
                        }
                    }
                ]
            },
            "done": true,
            "total_duration": 8000000000u64,
            "load_duration": 6000000u64,
            "prompt_eval_count": 61u64,
            "prompt_eval_duration": 400000000u64,
            "eval_count": 468u64,
            "eval_duration": 7700000000u64
        });

        // Round-trip through text, as the provider response arrives as a string.
        let serialized = fixture.to_string();
        let parsed: CompletionResponse =
            serde_json::from_str(&serialized).expect("Invalid JSON structure");
        let converted: completion::CompletionResponse<CompletionResponse> =
            parsed.try_into().unwrap();

        assert!(
            !converted.choice.is_empty(),
            "Expected non-empty choice in chat response"
        );
    }

    /// A provider user message converts into a generic user message whose
    /// first part is the same text.
    #[test]
    fn test_message_conversion() {
        let provider_msg = Message::User {
            content: "Test message".to_owned(),
            images: None,
            name: None,
        };

        let converted: crate::completion::Message = provider_msg.into();
        let crate::completion::Message::User { content } = converted else {
            panic!("Conversion from provider Message to completion Message failed")
        };
        let crate::completion::message::UserContent::Text(text_struct) = content.first() else {
            panic!("Expected text content in conversion")
        };
        assert_eq!(text_struct.text, "Test message");
    }

    /// A generic tool definition converts into the provider shape with the
    /// `function` type tag and its fields intact.
    #[test]
    fn test_tool_definition_conversion() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the weather for, e.g. San Francisco, CA"
                },
                "format": {
                    "type": "string",
                    "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "format"]
        });
        let internal_tool = crate::completion::ToolDefinition {
            name: "get_current_weather".to_owned(),
            description: "Get the current weather for a location".to_owned(),
            parameters,
        };

        let ollama_tool = ToolDefinition::from(internal_tool);

        assert_eq!(ollama_tool.type_field, "function");
        assert_eq!(ollama_tool.function.name, "get_current_weather");
        assert_eq!(
            ollama_tool.function.description,
            "Get the current weather for a location"
        );
        let location_type = &ollama_tool.function.parameters["properties"]["location"]["type"];
        assert_eq!(*location_type, "string");
    }
}