// llm-cascade 0.1.0
//
// Resilient cascading LLM inference with automatic failover across multiple providers.
//! Ollama local inference provider.

use reqwest::Client;
use serde_json::{json, Value};

use crate::error::ProviderError;
use crate::models::{ContentBlock, Conversation, LlmResponse, MessageRole};
use crate::providers::LlmProvider;

/// Provider for a local Ollama inference server. No API key required.
pub struct OllamaProvider {
    /// Shared HTTP client, reused across requests for connection pooling.
    client: Client,
    /// Base URL of the Ollama server (default `http://localhost:11434`).
    base_url: String,
    /// Name of the model sent with every chat request.
    model: String,
}

impl OllamaProvider {
    /// Creates a new Ollama provider for `model`.
    ///
    /// When `base_url` is `None`, the conventional local endpoint
    /// `http://localhost:11434` is used.
    pub fn new(model: String, base_url: Option<String>) -> Self {
        let base_url = match base_url {
            Some(url) => url,
            None => String::from("http://localhost:11434"),
        };
        Self {
            client: Client::new(),
            base_url,
            model,
        }
    }
}

#[async_trait::async_trait]
impl LlmProvider for OllamaProvider {
    async fn complete(&self, conversation: &Conversation) -> Result<LlmResponse, ProviderError> {
        let mut messages = Vec::new();

        for msg in &conversation.messages {
            let role = match msg.role {
                MessageRole::System => "system",
                MessageRole::User => "user",
                MessageRole::Assistant => "assistant",
                MessageRole::Tool => "tool",
            };

            let mut message = json!({
                "role": role,
                "content": msg.content,
            });

            if let Some(ref tool_call_id) = msg.tool_call_id {
                message["tool_call_id"] = json!(tool_call_id);
            }

            messages.push(message);
        }

        let mut body = json!({
            "model": self.model,
            "messages": messages,
            "stream": false,
        });

        if let Some(ref tools) = conversation.tools {
            let ollama_tools: Vec<Value> = tools.iter().map(|t| {
                json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    }
                })
            }).collect();
            body["tools"] = json!(ollama_tools);
        }

        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));

        let response = self.client.post(&url).json(&body).send().await?;

        let status = response.status();
        let retry_after = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<u64>().ok());

        if !status.is_success() {
            let error_body = response.text().await.unwrap_or_default();
            return Err(ProviderError::Http {
                status: status.as_u16(),
                body: error_body,
                retry_after,
            });
        }

        let data: Value = response.json().await.map_err(|e| ProviderError::Parse(e.to_string()))?;

        let message = &data["message"];
        let mut content_blocks = Vec::new();

        if let Some(content) = message["content"].as_str()
            && !content.is_empty()
        {
            content_blocks.push(ContentBlock::Text { text: content.to_string() });
        }

        if let Some(tool_calls) = message["tool_calls"].as_array() {
            for tc in tool_calls {
                let function = &tc["function"];
                let name = function["name"].as_str().unwrap_or("").to_string();
                let arguments = function["arguments"].to_string();
                let id = tc["id"].as_str().unwrap_or("").to_string();
                content_blocks.push(ContentBlock::ToolCall { id, name, arguments });
            }
        }

        let eval_count = data.get("eval_count").and_then(|v| v.as_u64()).map(|v| v as u32);
        let prompt_eval_count = data.get("prompt_eval_count").and_then(|v| v.as_u64()).map(|v| v as u32);

        Ok(LlmResponse {
            content: content_blocks,
            input_tokens: prompt_eval_count,
            output_tokens: eval_count,
            model: self.model.clone(),
        })
    }

    fn provider_name(&self) -> &str {
        "ollama"
    }

    fn model_name(&self) -> &str {
        &self.model
    }
}