use crate::client::{CompletionRequest, LlmClient, Role, TokenStream};
use crate::error::Error;
use async_stream::try_stream;
use async_trait::async_trait;
use futures::StreamExt;
/// Client for an Ollama server's HTTP API (chat completion, streaming chat,
/// and embeddings).
#[derive(Debug, Clone)]
pub struct OllamaClient {
    /// HTTP client used for every request (configured in [`OllamaClient::new`]).
    client: reqwest::Client,
    /// Chat model used when a request has no `model_override`; `None` selects
    /// the built-in default returned by `default_model`.
    model: Option<String>,
    /// Server root URL; endpoint paths such as `/api/chat` are appended to it.
    pub(crate) base_url: String,
}
impl OllamaClient {
pub fn new(model: Option<String>, base_url: Option<String>) -> Self {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(60))
.build()
.expect("failed to build reqwest client");
let base_url = base_url.unwrap_or_else(|| "http://localhost:11434".to_string());
Self {
client,
model,
base_url,
}
}
pub(crate) fn embed_model() -> String {
std::env::var("FERRO_AI_EMBED_MODEL").unwrap_or_else(|_| "nomic-embed-text".to_string())
}
pub(crate) fn build_body(
&self,
request: &CompletionRequest,
stream: bool,
) -> serde_json::Value {
let model = request
.model_override
.as_deref()
.unwrap_or_else(|| self.default_model());
let mut messages: Vec<serde_json::Value> = Vec::new();
if let Some(system) = &request.system {
messages.push(serde_json::json!({
"role": "system",
"content": system,
}));
}
for m in &request.messages {
messages.push(serde_json::json!({
"role": match m.role {
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
},
"content": m.content,
}));
}
serde_json::json!({
"model": model,
"messages": messages,
"stream": stream,
})
}
}
pub(crate) fn parse_ollama_line(line: &str) -> Result<(Option<String>, bool), Error> {
let v: serde_json::Value =
serde_json::from_str(line).map_err(|e| Error::Deserialization(e.to_string()))?;
let done = v["done"].as_bool().unwrap_or(false);
let token = v["message"]["content"]
.as_str()
.filter(|s| !s.is_empty())
.map(|s| s.to_string());
Ok((token, done))
}
pub(crate) fn parse_ollama_embedding(json: &serde_json::Value) -> Result<Vec<f32>, Error> {
json["embeddings"][0]
.as_array()
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect()
})
.ok_or_else(|| Error::Deserialization("no embeddings in response".into()))
}
#[async_trait]
impl LlmClient for OllamaClient {
fn default_model(&self) -> &str {
self.model.as_deref().unwrap_or("llama3.1")
}
async fn complete(&self, request: CompletionRequest) -> Result<String, Error> {
let body = self.build_body(&request, false);
let resp = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&body)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
Error::Timeout
} else {
Error::Provider {
status: None,
message: e.to_string(),
}
}
})?;
let status = resp.status().as_u16();
if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(Error::Provider {
status: Some(status),
message: text,
});
}
let json: serde_json::Value = resp
.json()
.await
.map_err(|e| Error::Deserialization(e.to_string()))?;
json["message"]["content"]
.as_str()
.map(|s| s.to_string())
.ok_or_else(|| Error::Deserialization("no content in response".into()))
}
async fn complete_stream(&self, request: CompletionRequest) -> Result<TokenStream, Error> {
let body = self.build_body(&request, true);
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&body)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
Error::Timeout
} else {
Error::Provider {
status: None,
message: e.to_string(),
}
}
})?;
let status = response.status().as_u16();
if !response.status().is_success() {
let text = response.text().await.unwrap_or_default();
return Err(Error::Provider {
status: Some(status),
message: text,
});
}
let stream = Box::pin(try_stream! {
let mut bytes = response.bytes_stream();
let mut buf = String::new();
while let Some(chunk) = bytes.next().await {
let chunk = chunk.map_err(|e| Error::Provider {
status: None,
message: e.to_string(),
})?;
buf.push_str(&String::from_utf8_lossy(&chunk));
while let Some(newline_pos) = buf.find('\n') {
let line = buf[..newline_pos].trim().to_string();
buf = buf[newline_pos + 1..].to_string();
if line.is_empty() {
continue;
}
let (token, done) = parse_ollama_line(&line)?;
if let Some(text) = token {
yield text;
}
if done {
return;
}
}
}
});
Ok(stream)
}
async fn embed(&self, text: &str) -> Result<Vec<f32>, Error> {
let model = Self::embed_model();
let body = serde_json::json!({
"model": model,
"input": text,
});
let resp = self
.client
.post(format!("{}/api/embed", self.base_url))
.json(&body)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
Error::Timeout
} else {
Error::Provider {
status: None,
message: e.to_string(),
}
}
})?;
let status = resp.status().as_u16();
if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(Error::Provider {
status: Some(status),
message: text,
});
}
let json: serde_json::Value = resp
.json()
.await
.map_err(|e| Error::Deserialization(e.to_string()))?;
parse_ollama_embedding(&json)
}
}
#[cfg(test)]
mod tests {
    use super::{parse_ollama_embedding, parse_ollama_line, OllamaClient};
    use crate::client::LlmClient;
    use crate::error::Error;

    #[test]
    fn test_ollama_default_model() {
        let client = OllamaClient::new(None, None);
        assert_eq!(client.default_model(), "llama3.1");
    }

    #[test]
    fn test_ollama_model_override() {
        let client = OllamaClient::new(Some("mistral".into()), None);
        assert_eq!(client.default_model(), "mistral");
    }

    #[test]
    fn test_ollama_default_base_url() {
        let client = OllamaClient::new(None, None);
        assert_eq!(client.base_url, "http://localhost:11434");
    }

    #[test]
    fn test_parse_ollama_line_token() {
        let line = r#"{"message":{"content":"The"},"done":false}"#;
        let (token, done) = parse_ollama_line(line).unwrap();
        assert_eq!(token, Some("The".to_string()));
        assert!(!done);
    }

    #[test]
    fn test_parse_ollama_line_done() {
        let line = r#"{"message":{"content":""},"done":true}"#;
        let (token, done) = parse_ollama_line(line).unwrap();
        assert_eq!(token, None);
        assert!(done);
    }

    #[test]
    fn test_parse_ollama_line_invalid_json() {
        assert!(matches!(
            parse_ollama_line("{not json"),
            Err(Error::Deserialization(_))
        ));
    }

    #[test]
    fn test_parse_ollama_embedding() {
        let json = serde_json::json!({
            "embeddings": [[0.1_f64, -0.2_f64]],
            "total_duration": 12345
        });
        let result = parse_ollama_embedding(&json).unwrap();
        assert_eq!(result.len(), 2);
        assert!((result[0] - 0.1f32).abs() < 1e-6);
        assert!((result[1] - (-0.2f32)).abs() < 1e-6);
    }

    #[test]
    fn test_parse_ollama_embedding_missing() {
        let json = serde_json::json!({"embeddings": []});
        assert!(matches!(
            parse_ollama_embedding(&json),
            Err(Error::Deserialization(_))
        ));
    }

    #[test]
    fn test_parse_ollama_embedding_no_key() {
        let json = serde_json::json!({"total_duration": 1});
        assert!(matches!(
            parse_ollama_embedding(&json),
            Err(Error::Deserialization(_))
        ));
    }

    #[test]
    fn test_ollama_is_object_safe() {
        let _: Box<dyn LlmClient> = Box::new(OllamaClient::new(None, None));
    }

    #[test]
    fn embed_model_default_is_nomic() {
        // Tolerate poisoning so one failed env test does not fail all others.
        let _g = crate::ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        std::env::remove_var("FERRO_AI_EMBED_MODEL");
        assert_eq!(OllamaClient::embed_model(), "nomic-embed-text");
    }

    #[test]
    fn embed_model_from_env() {
        // Tolerate poisoning so one failed env test does not fail all others.
        let _g = crate::ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        std::env::set_var("FERRO_AI_EMBED_MODEL", "mxbai-embed-large");
        let model = OllamaClient::embed_model();
        // Clear before asserting so a failure cannot leak the override into other tests.
        std::env::remove_var("FERRO_AI_EMBED_MODEL");
        assert_eq!(model, "mxbai-embed-large");
    }
}