use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
use ollama_rs::generation::completion::request::GenerationRequest;
use ollama_rs::generation::options::GenerationOptions;
use ollama_rs::Ollama;
use tokio::sync::RwLock;
use crate::debug_channel::{LlmDebugChannel, LlmDebugEvent};
use crate::decider::{
    DecisionResponse, LlmDecider, LlmDeciderConfig, LlmError, LoraConfig, WorkerDecisionRequest,
};
use crate::prompt_builder::PromptBuilder;
use crate::response_parser;

/// Configuration for [`OllamaDecider`]: the shared decider settings plus
/// Ollama-specific generation options.
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    pub base: LlmDeciderConfig,
    /// Maximum number of tokens to generate per call (Ollama's `num_predict`).
    pub num_predict: usize,
    /// Context window size in tokens (Ollama's `num_ctx`).
    pub num_ctx: usize,
}

impl Default for OllamaConfig {
    fn default() -> Self {
        Self {
            base: LlmDeciderConfig::default(),
            num_predict: 256,
            num_ctx: 2048,
        }
    }
}
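
/// An [`LlmDecider`] that generates decisions through an Ollama HTTP server.
///
/// A minimal construction sketch; the model name and endpoint below are
/// illustrative values, not guarantees about this crate's defaults:
///
/// ```ignore
/// let mut config = OllamaConfig::default();
/// config.base.model = "llama3".to_string();
/// config.base.endpoint = "http://localhost:11434".to_string();
/// config.num_ctx = 4096;
///
/// let decider = OllamaDecider::new(config);
/// assert_eq!(decider.model_name(), "llama3");
/// ```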
pub struct OllamaDecider {
    config: OllamaConfig,
    client: Arc<RwLock<Ollama>>,
    prompt_builder: PromptBuilder,
}

impl OllamaDecider {
    /// Builds a decider, deriving the Ollama host and port from
    /// `config.base.endpoint`.
    pub fn new(config: OllamaConfig) -> Self {
        let (host, port) = Self::parse_endpoint(&config.base.endpoint);
        let client = Ollama::new(host, port);
        Self {
            config,
            client: Arc::new(RwLock::new(client)),
            prompt_builder: PromptBuilder::new(),
        }
    }

    /// Splits an endpoint such as `http://localhost:11434` on its last colon
    /// into the host and port expected by [`Ollama::new`].
    fn parse_endpoint(endpoint: &str) -> (String, u16) {
        if let Some(pos) = endpoint.rfind(':') {
            let host = &endpoint[..pos];
            let port_str = &endpoint[pos + 1..];
            if let Ok(port) = port_str.parse::<u16>() {
                return (host.to_string(), port);
            }
        }
        // No parseable port: keep the configured host (rather than silently
        // falling back to localhost) and use Ollama's default port 11434.
        if endpoint.is_empty() {
            ("http://localhost".to_string(), 11434)
        } else {
            (endpoint.to_string(), 11434)
        }
    }

    /// Sends `prompt` to Ollama and returns the raw response text together
    /// with the call latency in milliseconds.
    async fn call_ollama(&self, prompt: &str) -> Result<(String, u64), LlmError> {
        let start = Instant::now();
        let client = self.client.read().await;
        // Forward the configured generation options; without this,
        // `num_predict` and `num_ctx` were stored but never sent to Ollama.
        let options = GenerationOptions::default()
            .num_predict(self.config.num_predict as _)
            .num_ctx(self.config.num_ctx as _);
        let request = GenerationRequest::new(self.config.base.model.clone(), prompt.to_string())
            .options(options);
        match client.generate(request).await {
            Ok(response) => {
                let latency_ms = start.elapsed().as_millis() as u64;
                Ok((response.response, latency_ms))
            }
            Err(e) => {
                tracing::warn!(error = %e, "Ollama API call failed");
                // Network and server failures are retryable, hence transient.
                Err(LlmError::transient(e.to_string()))
            }
        }
    }

    fn emit_debug_event(&self, event: LlmDebugEvent) {
        LlmDebugChannel::global().emit(event);
    }
}

impl Default for OllamaDecider {
    fn default() -> Self {
        Self::new(OllamaConfig::default())
    }
}

impl LlmDecider for OllamaDecider {
    fn decide(
        &self,
        request: WorkerDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<DecisionResponse, LlmError>> + Send + '_>> {
        Box::pin(async move {
            let prompt = self.prompt_builder.build(&request.context);
            let worker_id = request.worker_id.0;
            let (raw_response, latency_ms) = match self.call_ollama(&prompt).await {
                Ok(result) => result,
                Err(e) => {
                    self.emit_debug_event(
                        LlmDebugEvent::new("decide", &self.config.base.model)
                            .worker_id(worker_id)
                            .endpoint(&self.config.base.endpoint)
                            .prompt(&prompt)
                            .error(e.message()),
                    );
                    return Err(e);
                }
            };
            let candidate_names = response_parser::candidate_names(&request.context.candidates);
            match response_parser::parse_response(&raw_response, &candidate_names) {
                Ok(mut decision) => {
                    self.emit_debug_event(
                        LlmDebugEvent::new("decide", &self.config.base.model)
                            .worker_id(worker_id)
                            .endpoint(&self.config.base.endpoint)
                            .prompt(&prompt)
                            .response(&raw_response)
                            .latency_ms(latency_ms),
                    );
                    // Attach the prompt and raw model output for downstream debugging.
                    decision.prompt = Some(prompt);
                    decision.raw_response = Some(raw_response);
                    Ok(decision)
                }
                Err(e) => {
                    self.emit_debug_event(
                        LlmDebugEvent::new("decide", &self.config.base.model)
                            .worker_id(worker_id)
                            .endpoint(&self.config.base.endpoint)
                            .prompt(&prompt)
                            .response(&raw_response)
                            .error(e.message())
                            .latency_ms(latency_ms),
                    );
                    tracing::warn!(error = %e, "Failed to parse model response");
                    tracing::debug!(raw = %raw_response, "Raw model response");
                    Err(e)
                }
            }
        })
    }

    fn call_raw(
        &self,
        prompt: &str,
        _lora: Option<&LoraConfig>,
    ) -> Pin<Box<dyn Future<Output = Result<String, LlmError>> + Send + '_>> {
        // LoRA adapters are not applied on the Ollama path; the parameter is
        // accepted for trait compatibility and ignored.
        let prompt = prompt.to_string();
        Box::pin(async move {
            match self.call_ollama(&prompt).await {
                Ok((response, latency_ms)) => {
                    self.emit_debug_event(
                        LlmDebugEvent::new("call_raw", &self.config.base.model)
                            .endpoint(&self.config.base.endpoint)
                            .prompt(&prompt)
                            .response(&response)
                            .latency_ms(latency_ms),
                    );
                    Ok(response)
                }
                Err(e) => {
                    self.emit_debug_event(
                        LlmDebugEvent::new("call_raw", &self.config.base.model)
                            .endpoint(&self.config.base.endpoint)
                            .prompt(&prompt)
                            .error(e.message()),
                    );
                    Err(e)
                }
            }
        })
    }

    fn model_name(&self) -> &str {
        &self.config.base.model
    }

    fn endpoint(&self) -> &str {
        &self.config.base.endpoint
    }

    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
        Box::pin(async move {
            // Listing local models doubles as a cheap liveness probe for the daemon.
            let client = self.client.read().await;
            client.list_local_models().await.is_ok()
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::decider::ActionCandidate;
    use swarm_engine_core::agent::{ActionParam, ContextTarget, GlobalContext, ResolvedContext};

    fn create_test_candidates() -> Vec<ActionCandidate> {
        vec![
            ActionCandidate {
                name: "Read".to_string(),
                description: "Read a file".to_string(),
                params: vec![ActionParam {
                    name: "path".to_string(),
                    description: "File path".to_string(),
                    required: true,
                }],
                example: None,
            },
            ActionCandidate {
                name: "Grep".to_string(),
                description: "Search pattern".to_string(),
                params: vec![ActionParam {
                    name: "pattern".to_string(),
                    description: "Search pattern".to_string(),
                    required: true,
                }],
                example: None,
            },
        ]
    }

    fn create_test_context(candidates: Vec<ActionCandidate>) -> ResolvedContext {
        let global = GlobalContext {
            tick: 1,
            max_ticks: 100,
            progress: 0.5,
            success_rate: 0.8,
            task_description: Some("test task".to_string()),
            hint: None,
        };
        ResolvedContext::new(
            global,
            ContextTarget::Worker(swarm_engine_core::types::WorkerId(0)),
        )
        .with_candidates(candidates)
    }
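
    // Grounded in `parse_endpoint` above: the split happens at the last
    // colon, and an endpoint without a parseable port keeps its host with
    // Ollama's default port.
    #[test]
    fn test_parse_endpoint() {
        assert_eq!(
            OllamaDecider::parse_endpoint("http://localhost:11434"),
            ("http://localhost".to_string(), 11434)
        );
        assert_eq!(
            OllamaDecider::parse_endpoint("http://remote-host"),
            ("http://remote-host".to_string(), 11434)
        );
    }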

    #[test]
    fn test_prompt_builder_integration() {
        let decider = OllamaDecider::default();
        let context = create_test_context(create_test_candidates());
        let prompt = decider.prompt_builder.build(&context);
        assert!(prompt.contains("JSON-only response AI"));
        assert!(prompt.contains("Example interaction:"));
        assert!(prompt.contains("## Task"));
        assert!(prompt.contains("## Available Actions"));
        assert!(prompt.contains("- Read: Read a file"));
        assert!(prompt.contains("- Grep: Search pattern"));
        assert!(prompt.contains("Your JSON:"));
    }

    #[test]
    fn test_prompt_with_self_last_output() {
        let decider = OllamaDecider::default();
        let global = GlobalContext::new(10)
            .with_max_ticks(100)
            .with_progress(0.5)
            .with_task("Find the bug");
        let context = ResolvedContext::new(
            global,
            ContextTarget::Worker(swarm_engine_core::types::WorkerId(0)),
        )
        .with_self_last_output(Some("Found 3 files matching pattern".to_string()))
        .with_candidates(create_test_candidates());
        let prompt = decider.prompt_builder.build(&context);
        assert!(prompt.contains("## Last Result"));
        assert!(prompt.contains("Found 3 files matching pattern"));
        assert!(!prompt.contains("## Your Status"));
        assert!(!prompt.contains("## Team Status"));
    }
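
    // An end-to-end sketch against a live daemon. It assumes `ollama serve`
    // is reachable at the default endpoint and that tokio's `rt`/`macros`
    // features are enabled for tests, so it is `#[ignore]`d by default.
    #[tokio::test]
    #[ignore]
    async fn test_call_raw_against_live_server() {
        let decider = OllamaDecider::default();
        if !decider.is_healthy().await {
            return; // no daemon available; nothing to assert
        }
        let result = decider.call_raw("Reply with a single word.", None).await;
        assert!(result.is_ok());
    }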
}