//! Ollama LLM integration
//!
//! This module provides integration with Ollama for local LLM inference.

use crate::core::{GraphRAGError, Result};

/// Ollama configuration
///
/// Connection, model-selection, and generation settings for a local Ollama
/// server. Derives `serde` traits so it can be loaded from / saved to
/// configuration files.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OllamaConfig {
    /// Enable Ollama integration
    pub enabled: bool,
    /// Ollama host URL (scheme + host, e.g. "http://localhost"; the port is stored separately)
    pub host: String,
    /// Ollama port
    pub port: u16,
    /// Model for embeddings
    pub embedding_model: String,
    /// Model for chat/generation
    pub chat_model: String,
    /// Timeout in seconds applied to HTTP requests against the Ollama API
    pub timeout_seconds: u64,
    /// Maximum retry attempts
    pub max_retries: u32,
    /// Fallback to hash-based IDs on error
    pub fallback_to_hash: bool,
    /// Maximum tokens to generate (sent as Ollama's `num_predict`; `None` = server default)
    pub max_tokens: Option<u32>,
    /// Temperature for generation (0.0 - 1.0; `None` = server default)
    pub temperature: Option<f32>,
}

32impl Default for OllamaConfig {
33    fn default() -> Self {
34        Self {
35            enabled: false,
36            host: "http://localhost".to_string(),
37            port: 11434,
38            embedding_model: "nomic-embed-text".to_string(),
39            chat_model: "llama3.2:3b".to_string(),
40            timeout_seconds: 30,
41            max_retries: 3,
42            fallback_to_hash: true,
43            max_tokens: Some(2000),
44            temperature: Some(0.7),
45        }
46    }
47}
48
/// Ollama client for LLM inference
///
/// Bundles the configuration with a reusable HTTP agent; the agent exists
/// only when the `ureq` feature is enabled.
#[derive(Debug, Clone)]
pub struct OllamaClient {
    // Connection and generation settings used to build every request.
    config: OllamaConfig,
    // Reusable HTTP agent, compiled in only with the `ureq` feature.
    #[cfg(feature = "ureq")]
    client: ureq::Agent,
}

57impl OllamaClient {
58    /// Create a new Ollama client
59    pub fn new(config: OllamaConfig) -> Self {
60        Self {
61            config: config.clone(),
62            #[cfg(feature = "ureq")]
63            client: ureq::AgentBuilder::new()
64                .timeout(std::time::Duration::from_secs(config.timeout_seconds))
65                .build(),
66        }
67    }
68
69    /// Generate text completion using Ollama API
70    #[cfg(feature = "ureq")]
71    pub async fn generate(&self, prompt: &str) -> Result<String> {
72        let endpoint = format!("{}:{}/api/generate", self.config.host, self.config.port);
73
74        let mut request_body = serde_json::json!({
75            "model": self.config.chat_model,
76            "prompt": prompt,
77            "stream": false,
78        });
79
80        // Add optional parameters
81        if let Some(max_tokens) = self.config.max_tokens {
82            request_body["options"] = serde_json::json!({
83                "num_predict": max_tokens,
84            });
85        }
86
87        if let Some(temperature) = self.config.temperature {
88            if request_body.get("options").is_none() {
89                request_body["options"] = serde_json::json!({});
90            }
91            request_body["options"]["temperature"] = serde_json::json!(temperature);
92        }
93
94        // Make HTTP request with retry logic
95        let mut last_error = None;
96        for attempt in 1..=self.config.max_retries {
97            match self.client
98                .post(&endpoint)
99                .set("Content-Type", "application/json")
100                .send_json(&request_body)
101            {
102                Ok(response) => {
103                    let json_response: serde_json::Value = response
104                        .into_json()
105                        .map_err(|e| GraphRAGError::Generation {
106                            message: format!("Failed to parse JSON response: {}", e),
107                        })?;
108
109                    // Extract response text
110                    if let Some(response_text) = json_response["response"].as_str() {
111                        return Ok(response_text.to_string());
112                    } else {
113                        return Err(GraphRAGError::Generation {
114                            message: format!("Invalid response format: {:?}", json_response),
115                        });
116                    }
117                }
118                Err(e) => {
119                    tracing::warn!("Ollama API request failed (attempt {}): {}", attempt, e);
120                    last_error = Some(e);
121
122                    if attempt < self.config.max_retries {
123                        // Wait before retry (exponential backoff)
124                        tokio::time::sleep(std::time::Duration::from_millis(100 * attempt as u64)).await;
125                    }
126                }
127            }
128        }
129
130        Err(GraphRAGError::Generation {
131            message: format!("Ollama API failed after {} retries: {:?}",
132                           self.config.max_retries, last_error),
133        })
134    }
135
136    /// Generate text completion (sync fallback when ureq feature is disabled)
137    #[cfg(not(feature = "ureq"))]
138    pub async fn generate(&self, _prompt: &str) -> Result<String> {
139        Err(GraphRAGError::Generation {
140            message: "ureq feature required for Ollama integration".to_string(),
141        })
142    }
143}