use std::env;
use crate::{http_config::build_client, preview};
use async_trait::async_trait;
use futures_util::{StreamExt, stream::BoxStream};
use reqwest::header::HeaderMap;
use rucora_core::{
error::ProviderError,
provider::{
LlmProvider,
types::{
ChatMessage, ChatRequest, ChatResponse, ChatStreamChunk, FinishReason, ResponseFormat,
Role, Usage,
},
},
};
use serde_json::{Value, json};
use tracing::debug;
/// Model name used as the fallback default: `OllamaProvider::new` always uses it, and
/// `OllamaProvider::from_env` uses it when the `OLLAMA_DEFAULT_MODEL` environment
/// variable cannot be read.
pub const OLLAMA_DEFAULT_MODEL: &str = "llama3.1:8b";
/// `LlmProvider` implementation that talks to an Ollama server over HTTP.
///
/// Requests are posted to `{base_url}/api/chat`; a trailing `/` on the base URL is
/// trimmed before the path is appended.
#[derive(Clone)]
pub struct OllamaProvider {
// HTTP client produced by `build_client` with an empty header map.
client: reqwest::Client,
// Root URL of the Ollama server, e.g. `http://localhost:11434`.
base_url: String,
// Model used when a `ChatRequest` does not name one.
default_model: String,
}
impl OllamaProvider {
/// Translates a transport-level `reqwest::Error` into a `ProviderError`.
///
/// * timeouts become `ProviderError::Timeout` (the elapsed time is not known here, so
///   it is reported as zero);
/// * connect / request failures become a retriable `ProviderError::Network` that keeps
///   the original error as its source;
/// * everything else is flattened into `ProviderError::Message`.
fn map_reqwest_error(e: reqwest::Error) -> ProviderError {
    // Render the message up front: the network branch moves `e` into `source`.
    let message = e.to_string();

    if e.is_timeout() {
        return ProviderError::Timeout {
            message,
            elapsed: std::time::Duration::ZERO,
        };
    }

    if !(e.is_connect() || e.is_request()) {
        return ProviderError::Message(message);
    }

    ProviderError::Network {
        message,
        source: Some(Box::new(e)),
        retriable: true,
    }
}
/// Maps a non-success HTTP status plus a descriptive message to a `ProviderError`.
///
/// 401 and 403 are reported as authentication failures, 429 as a rate limit (without
/// a retry hint), and any other status as a generic API error carrying the numeric
/// status code.
fn map_http_error(status: reqwest::StatusCode, message: String) -> ProviderError {
    let http_status = status.as_u16();

    if matches!(http_status, 401 | 403) {
        return ProviderError::Authentication { message };
    }

    if http_status == 429 {
        return ProviderError::RateLimit {
            message,
            retry_after: None,
        };
    }

    ProviderError::Api {
        status: http_status,
        message,
        code: None,
    }
}
/// Builds a provider from environment variables.
///
/// `OLLAMA_BASE_URL` supplies the server URL (fallback `http://localhost:11434`) and
/// `OLLAMA_DEFAULT_MODEL` supplies the default model (fallback
/// [`OLLAMA_DEFAULT_MODEL`]). A variable that cannot be read falls back silently.
pub fn from_env() -> Self {
    // Shared lookup: any `env::var` error selects the fallback.
    let read_or = |key: &str, fallback: &str| -> String {
        match env::var(key) {
            Ok(value) => value,
            Err(_) => fallback.to_string(),
        }
    };

    let server = read_or("OLLAMA_BASE_URL", "http://localhost:11434");
    let model = read_or("OLLAMA_DEFAULT_MODEL", OLLAMA_DEFAULT_MODEL);
    Self::with_model(server, model)
}
/// Creates a provider for the server at `base_url`, using [`OLLAMA_DEFAULT_MODEL`] as
/// the default model.
pub fn new(base_url: impl Into<String>) -> Self {
    let model = String::from(OLLAMA_DEFAULT_MODEL);
    Self::with_model(base_url, model)
}
/// Creates a provider for the server at `base_url` with an explicit default model.
///
/// The HTTP client is obtained from `build_client` with no extra default headers.
pub fn with_model(base_url: impl Into<String>, default_model: impl Into<String>) -> Self {
    let client = build_client(HeaderMap::new());
    let base_url: String = base_url.into();
    let default_model: String = default_model.into();
    Self {
        client,
        base_url,
        default_model,
    }
}
pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
self.default_model = model.into();
self
}
/// Returns the model name used when a request does not specify one.
pub fn default_model(&self) -> &str {
    self.default_model.as_str()
}
/// Returns the role string placed in the `role` field of an outgoing message.
fn map_role(role: &Role) -> &'static str {
    let label: &'static str = match *role {
        Role::Tool => "tool",
        Role::Assistant => "assistant",
        Role::User => "user",
        Role::System => "system",
    };
    label
}
fn build_messages(messages: &[ChatMessage]) -> Vec<Value> {
messages
.iter()
.map(|m| {
let mut obj = json!({
"role": Self::map_role(&m.role),
"content": m.content,
});
if let Some(name) = &m.name
&& let Some(map) = obj.as_object_mut()
{
map.insert("name".to_string(), Value::String(name.clone()));
}
obj
})
.collect()
}
}
#[async_trait]
impl LlmProvider for OllamaProvider {
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
let model = request
.model
.clone()
.unwrap_or_else(|| self.default_model.clone());
let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
let mut body = json!({
"model": model,
"messages": Self::build_messages(&request.messages),
"stream": false
});
if let Some(map) = body.as_object_mut() {
if let Some(temperature) = request.temperature {
map.insert("temperature".to_string(), json!(temperature));
}
if let Some(top_p) = request.top_p {
map.insert("top_p".to_string(), json!(top_p));
}
if let Some(top_k) = request.top_k {
map.insert("top_k".to_string(), json!(top_k));
}
if let Some(max_tokens) = request.max_tokens {
map.insert("max_tokens".to_string(), json!(max_tokens));
}
if let Some(frequency_penalty) = request.frequency_penalty {
map.insert("frequency_penalty".to_string(), json!(frequency_penalty));
}
if let Some(presence_penalty) = request.presence_penalty {
map.insert("presence_penalty".to_string(), json!(presence_penalty));
}
if let Some(stop) = request.stop
&& !stop.is_empty()
{
map.insert("stop".to_string(), json!(stop));
}
}
if let Some(tools) = &request.tools
&& let Some(map) = body.as_object_mut()
{
let tools_array: Vec<Value> = tools
.iter()
.map(|tool_def| {
json!({
"type": "function",
"function": {
"name": tool_def.name,
"description": tool_def.description.as_deref().unwrap_or(""),
"parameters": tool_def.input_schema
}
})
})
.collect();
map.insert("tools".to_string(), json!(tools_array));
}
if let Some(fmt) = request.response_format.as_ref() {
match fmt {
ResponseFormat::JsonObject => {
if let Some(map) = body.as_object_mut() {
map.insert("format".to_string(), json!("json"));
}
}
ResponseFormat::JsonSchema { .. } => {
return Err(ProviderError::Message(
"Ollama provider 暂不支持 JsonSchema 结构化输出".to_string(),
));
}
}
}
if let Some(extra) = &request.extra
&& let Some(map) = body.as_object_mut()
&& let Value::Object(extra_map) = extra
{
for (key, value) in extra_map {
map.insert(key.clone(), value.clone());
}
}
let last_user_preview = request
.messages
.iter()
.rev()
.find(|m| m.role == Role::User)
.map(|m| preview(&m.content, 600));
debug!(
provider = "ollama",
url = %url,
model = %body.get("model").and_then(|v| v.as_str()).unwrap_or(""),
messages_len = request.messages.len(),
tools_len = request.tools.as_ref().map_or(0, |t| t.len()),
last_user = last_user_preview.as_deref().unwrap_or(""),
"provider.chat.start"
);
debug!(provider = "ollama", body = %preview(&body.to_string(), 1200), "provider.chat.request_body");
let start = std::time::Instant::now();
let resp = self
.client
.post(url)
.json(&body)
.send()
.await
.map_err(Self::map_reqwest_error)?;
let status = resp.status();
let data: Value = resp.json().await.map_err(Self::map_reqwest_error)?;
let elapsed_ms = start.elapsed().as_millis() as u64;
debug!(provider = "ollama", status = %status, elapsed_ms, "provider.chat.http.done");
debug!(provider = "ollama", status = %status, body = %preview(&data.to_string(), 1200), "provider.chat.response_body");
if !status.is_success() {
return Err(Self::map_http_error(
status,
format!("Ollama 请求失败:status={status} body={data}"),
));
}
let content = data
.get("message")
.and_then(|m| m.get("content"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let tool_calls = data
.get("message")
.and_then(|m| m.get("tool_calls"))
.and_then(|tc| tc.as_array())
.map(|arr| {
arr.iter()
.filter_map(|tc| {
let function = tc.get("function")?;
let name = function.get("name")?.as_str()?.to_string();
let input = match function.get("arguments") {
Some(value @ Value::Object(_)) => value.clone(),
Some(Value::String(s)) => {
serde_json::from_str(s).unwrap_or_else(|_| json!({"_raw": s}))
}
_ => json!({}),
};
let id = tc
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if name.is_empty() {
None
} else {
Some(rucora_core::tool::types::ToolCall { id, name, input })
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
debug!(
provider = "ollama",
assistant_content_len = content.len(),
tool_calls_len = tool_calls.len(),
"provider.chat.parsed"
);
let usage = data.get("usage").map(|u| Usage {
prompt_tokens: u
.get("prompt_tokens")
.and_then(|v| v.as_u64())
.map_or(0, |v| v as u32),
completion_tokens: u
.get("completion_tokens")
.and_then(|v| v.as_u64())
.map_or(0, |v| v as u32),
total_tokens: u
.get("total_tokens")
.and_then(|v| v.as_u64())
.map_or(0, |v| v as u32),
});
let finish_reason = data
.get("done_reason")
.and_then(|v| v.as_str())
.map(|fr| match fr {
"stop" => FinishReason::Stop,
"length" => FinishReason::Length,
"tool_calls" => FinishReason::ToolCall,
_ => FinishReason::Other,
});
Ok(ChatResponse {
message: ChatMessage {
role: Role::Assistant,
content,
name: None,
},
tool_calls,
usage,
finish_reason,
})
}
fn stream_chat(
&self,
request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatStreamChunk, ProviderError>>, ProviderError> {
let model = request
.model
.clone()
.unwrap_or_else(|| self.default_model.clone());
let client = self.client.clone();
let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
let mut body = json!({
"model": model,
"messages": Self::build_messages(&request.messages),
"stream": true
});
if let Some(map) = body.as_object_mut() {
if let Some(temperature) = request.temperature {
map.insert("temperature".to_string(), json!(temperature));
}
if let Some(top_p) = request.top_p {
map.insert("top_p".to_string(), json!(top_p));
}
if let Some(top_k) = request.top_k {
map.insert("top_k".to_string(), json!(top_k));
}
if let Some(max_tokens) = request.max_tokens {
map.insert("max_tokens".to_string(), json!(max_tokens));
}
if let Some(frequency_penalty) = request.frequency_penalty {
map.insert("frequency_penalty".to_string(), json!(frequency_penalty));
}
if let Some(presence_penalty) = request.presence_penalty {
map.insert("presence_penalty".to_string(), json!(presence_penalty));
}
if let Some(stop) = request.stop
&& !stop.is_empty()
{
map.insert("stop".to_string(), json!(stop));
}
}
if let Some(fmt) = request.response_format.as_ref() {
match fmt {
ResponseFormat::JsonObject => {
if let Some(map) = body.as_object_mut() {
map.insert("format".to_string(), json!("json"));
}
}
ResponseFormat::JsonSchema { .. } => {
return Err(ProviderError::Message(
"Ollama provider 暂不支持 JsonSchema 结构化输出".to_string(),
));
}
}
}
if let Some(extra) = &request.extra
&& let Some(map) = body.as_object_mut()
&& let Value::Object(extra_map) = extra
{
for (key, value) in extra_map {
map.insert(key.clone(), value.clone());
}
}
debug!(
provider = "ollama",
url = %url,
model = %body.get("model").and_then(|v| v.as_str()).unwrap_or(""),
messages_len = request.messages.len(),
"provider.stream_chat.start"
);
debug!(provider = "ollama", body = %preview(&body.to_string(), 1200), "provider.stream_chat.request_body");
let stream = async_stream::try_stream! {
let start = std::time::Instant::now();
let resp = client
.post(url)
.json(&body)
.send()
.await
.map_err(Self::map_reqwest_error)?;
let status = resp.status();
if !status.is_success() {
Err(Self::map_http_error(
status,
format!("Ollama stream 请求失败:status={status}"),
))?;
}
debug!(
provider = "ollama",
status = %status,
elapsed_ms = start.elapsed().as_millis() as u64,
"provider.stream_chat.http.started"
);
let mut buf = String::new();
let mut bytes_stream = resp.bytes_stream();
while let Some(item) = bytes_stream.next().await {
let bytes = item.map_err(Self::map_reqwest_error)?;
let chunk = String::from_utf8_lossy(&bytes);
buf.push_str(&chunk);
while let Some(idx) = buf.find('\n') {
let line = buf[..idx].trim().to_string();
buf = buf[idx + 1..].to_string();
if line.is_empty() {
continue;
}
let v: Value = serde_json::from_str(&line)
.map_err(|e| ProviderError::Message(format!("NDJSON 解析失败: {e} line={line}")))?;
let done = v.get("done").and_then(|d| d.as_bool()).unwrap_or(false);
let delta = v
.get("message")
.and_then(|m| m.get("content"))
.and_then(|s| s.as_str())
.map(|s| s.to_string());
if delta.is_some() {
yield ChatStreamChunk {
delta,
tool_calls: vec![],
usage: None,
finish_reason: None,
};
}
if done {
break;
}
}
}
};
Ok(Box::pin(stream))
}
}