use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::io::{BufRead, BufReader};
use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
/// One chat message, as serialized into [`ChatRequest`]'s `messages` list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Author role string, serialized verbatim.
pub role: String,
/// Message text; deserializes to an empty string when the field is absent.
#[serde(default)]
pub content: String,
/// Tool invocations attached to this message; omitted from the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
/// A single tool invocation: an optional identifier plus the function to call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Call identifier. Defaults to an empty string when missing on input, and is
/// still serialized (as `""`) in that case because there is no skip attribute.
#[serde(default)]
pub id: String,
/// Name and arguments of the function being called.
pub function: FunctionCall,
}
/// The function part of a [`ToolCall`]: its name and JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
/// Function name.
pub name: String,
/// Arguments kept as untyped JSON ([`Value`]) and passed through as-is.
pub arguments: Value,
/// Optional index reported by the server on input.
// Only ever deserialized (`skip_serializing`, so it is never sent back), and
// nothing in this file reads it — hence the lint allowance. Presumably the
// position of the call within its message; TODO confirm against server docs.
#[allow(dead_code)]
#[serde(default, skip_serializing)]
pub index: Option<u32>,
}
/// JSON body posted to `{base}/api/chat` by [`Client::chat_stream`].
#[derive(Debug, Clone, Serialize)]
pub struct ChatRequest {
/// Model name, serialized verbatim.
pub model: String,
/// Messages sent with the request, serialized in vector order.
pub messages: Vec<Message>,
/// Streaming flag. `Client::chat_stream` overrides this to `true` on its own
/// copy, whatever the caller sets here.
pub stream: bool,
/// Tool definitions as raw JSON; the `tools` key is omitted entirely when empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<Value>,
/// Per-request options, serialized under the `options` key.
pub options: Options,
}
/// Options object serialized as the `options` field of [`ChatRequest`].
#[derive(Debug, Clone, Serialize)]
pub struct Options {
/// Sent as the `temperature` option.
pub temperature: f64,
/// Sent as the `num_ctx` option (presumably the context-window size in
/// tokens — confirm against the server's documentation).
pub num_ctx: u32,
}
/// One newline-delimited JSON object read from the streaming chat response.
#[derive(Debug, Deserialize)]
struct RawChunk {
/// Message fragment carried by this line. It has no default, so a line
/// without a `message` key fails to deserialize.
message: RawMessage,
/// Final-chunk marker; `false` when absent. `chat_stream` stops reading once it is `true`.
#[serde(default)]
done: bool,
}
/// The `message` object inside a [`RawChunk`].
///
/// Every field may be missing from the wire: `content` then becomes an empty
/// string (explicit `default`), while the `Option` fields become `None` through
/// serde's built-in handling of absent `Option` fields.
#[derive(Debug, Deserialize)]
struct RawMessage {
    /// Incremental answer text; may be empty for a given chunk.
    #[serde(default)]
    content: String,
    /// Incremental reasoning text, if the server sent any.
    thinking: Option<String>,
    /// Tool calls carried by this chunk, if any.
    tool_calls: Option<Vec<ToolCall>>,
}
/// Blocking HTTP client for the server's `/api/*` endpoints.
#[derive(Clone)]
pub struct Client {
/// Server root URL with trailing `/` characters removed (see `Client::new`).
base: String,
/// `ureq` agent used for every request, configured with timeouts in `Client::new`.
agent: ureq::Agent,
}
impl Client {
pub fn new(base_url: &str) -> Self {
let agent = ureq::AgentBuilder::new()
.timeout_connect(std::time::Duration::from_secs(10))
.timeout_read(std::time::Duration::from_secs(600))
.build();
Self {
base: base_url.trim_end_matches('/').to_string(),
agent,
}
}
pub fn is_healthy(&self) -> bool {
self.agent
.get(&format!("{}/api/tags", self.base))
.call()
.is_ok()
}
pub fn chat_stream<F>(
&self,
request: &ChatRequest,
show_thinking: bool,
cancel: Arc<AtomicBool>,
mut on_token: F,
) -> Result<(String, Option<Vec<ToolCall>>), String>
where
F: FnMut(&str, bool), {
let mut req = request.clone();
req.stream = true;
let url = format!("{}/api/chat", self.base);
let resp = match self.agent.post(&url).send_json(&req) {
Ok(r) => r,
Err(ureq::Error::Status(code, r)) => {
let body = r.into_string().unwrap_or_default();
return Err(format!("Ollama {code}: {body}"));
}
Err(e) => return Err(format!("Connection error: {e}")),
};
let reader = BufReader::new(resp.into_reader());
let mut content = String::new();
let mut tool_calls: Option<Vec<ToolCall>> = None;
for line in reader.lines() {
if cancel.load(Ordering::Relaxed) {
return Err("__cancelled__".into());
}
let line = line.map_err(|e| format!("Stream read error: {e}"))?;
if line.is_empty() {
continue;
}
let chunk: RawChunk = serde_json::from_str(&line)
.map_err(|e| format!("Stream parse error: {e}"))?;
if show_thinking {
if let Some(ref t) = chunk.message.thinking {
if !t.is_empty() {
on_token(t, true);
}
}
}
if !chunk.message.content.is_empty() {
on_token(&chunk.message.content.clone(), false);
content.push_str(&chunk.message.content);
}
if let Some(tc) = chunk.message.tool_calls {
if !tc.is_empty() {
tool_calls = Some(tc);
}
}
if chunk.done {
break;
}
}
Ok((content, tool_calls))
}
}