Skip to main content

graphrag_core/ollama/
mod.rs

1//! Ollama LLM integration
2//!
3//! This module provides integration with Ollama for local LLM inference.
4
5use crate::core::{GraphRAGError, Result};
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9/// Generation parameters for Ollama requests
10#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
11pub struct OllamaGenerationParams {
12    /// Maximum tokens to generate
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub num_predict: Option<u32>,
15    /// Temperature for sampling (0.0 - 1.0)
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub temperature: Option<f32>,
18    /// Top-p nucleus sampling threshold
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub top_p: Option<f32>,
21    /// Top-k sampling
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub top_k: Option<u32>,
24    /// Stop sequences
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub stop: Option<Vec<String>>,
27    /// Repeat penalty
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub repeat_penalty: Option<f32>,
30    /// Context window size in tokens.
31    ///
32    /// **Critical for long documents**: Ollama silently truncates prompts that exceed
33    /// the default context size (often 2k-8k tokens). Set this to accommodate the
34    /// full document + chunk + instructions when using Contextual Enrichment.
35    ///
36    /// For KV Cache efficiency, calculate as:
37    /// `tokens(instructions) + tokens(document) + tokens(max_chunk) + output_tokens + 5% margin`
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub num_ctx: Option<u32>,
40    /// How long to keep the model loaded in memory after the request (e.g. "1h", "30m", "0").
41    ///
42    /// **Critical for KV Cache**: Without this, Ollama may unload the model between
43    /// consecutive requests, destroying the KV cache and forcing a full re-evaluation
44    /// of the static document prefix for every chunk. Set to "1h" when processing
45    /// multiple chunks from the same document.
46    ///
47    /// This is a top-level Ollama API field, not an option — serialized separately.
48    #[serde(skip)]
49    pub keep_alive: Option<String>,
50
51    /// KV cache context from a previous `/api/generate` response.
52    ///
53    /// When set, the model **continues from this token state** instead of re-evaluating
54    /// the entire prompt. Use this for the two-step KV cache pattern:
55    ///
56    /// 1. **Prime**: send the full document, get `context` back (loads doc into KV cache)
57    /// 2. **Per chunk**: send only the chunk text with the priming `context`
58    ///    → Ollama skips document re-evaluation, only evaluates ~128 chunk tokens
59    ///
60    /// This is a top-level Ollama API field — serialized separately.
61    #[serde(skip)]
62    pub context: Option<Vec<i64>>,
63}
64
65impl Default for OllamaGenerationParams {
66    fn default() -> Self {
67        Self {
68            num_predict: Some(2000),
69            temperature: Some(0.7),
70            top_p: Some(0.9),
71            top_k: Some(40),
72            stop: None,
73            repeat_penalty: Some(1.1),
74            num_ctx: None,
75            keep_alive: None,
76            context: None,
77        }
78    }
79}
80
/// Complete `/api/generate` result: generated text, KV cache state and token counts.
///
/// Produced by [`OllamaClient::generate_with_full_response`]. It supports the
/// two-step KV cache workflow: prime once with the document, then enrich each
/// chunk cheaply by passing `context` back in.
#[derive(Debug, Clone)]
pub struct OllamaGenerateResponse {
    /// Text produced by the model.
    pub text: String,
    /// KV cache token state after this request.
    ///
    /// Feed it back through `OllamaGenerationParams::context` so the next
    /// request resumes here without re-evaluating earlier tokens. Empty when
    /// the server returned no `context`.
    pub context: Vec<i64>,
    /// Number of prompt tokens the server actually evaluated, as opposed to
    /// tokens reused from the KV cache.
    ///
    /// Roughly `chunk_tokens` when the cache is reused, and roughly
    /// `full_prompt_tokens` when it is not. `0` when the server omitted it.
    pub prompt_eval_count: u64,
    /// Number of tokens generated in the response (`0` if not reported).
    pub eval_count: u64,
}
98
/// Request and token counters for an Ollama client.
///
/// Each counter is an `Arc<AtomicU64>`, so cloning this struct yields a
/// handle to the *same* counters rather than an independent copy.
#[derive(Debug, Clone, Default)]
pub struct OllamaUsageStats {
    /// Every recorded request, successful or not.
    pub total_requests: Arc<AtomicU64>,
    /// Requests recorded as successful.
    pub successful_requests: Arc<AtomicU64>,
    /// Requests recorded as failed.
    pub failed_requests: Arc<AtomicU64>,
    /// Approximate number of tokens processed.
    pub total_tokens: Arc<AtomicU64>,
}
111
112impl OllamaUsageStats {
113    /// Create new usage statistics
114    pub fn new() -> Self {
115        Self::default()
116    }
117
118    /// Record a successful request
119    pub fn record_success(&self, tokens: u64) {
120        self.total_requests.fetch_add(1, Ordering::Relaxed);
121        self.successful_requests.fetch_add(1, Ordering::Relaxed);
122        self.total_tokens.fetch_add(tokens, Ordering::Relaxed);
123    }
124
125    /// Record a failed request
126    pub fn record_failure(&self) {
127        self.total_requests.fetch_add(1, Ordering::Relaxed);
128        self.failed_requests.fetch_add(1, Ordering::Relaxed);
129    }
130
131    /// Get total requests
132    pub fn get_total_requests(&self) -> u64 {
133        self.total_requests.load(Ordering::Relaxed)
134    }
135
136    /// Get successful requests
137    pub fn get_successful_requests(&self) -> u64 {
138        self.successful_requests.load(Ordering::Relaxed)
139    }
140
141    /// Get failed requests
142    pub fn get_failed_requests(&self) -> u64 {
143        self.failed_requests.load(Ordering::Relaxed)
144    }
145
146    /// Get total tokens
147    pub fn get_total_tokens(&self) -> u64 {
148        self.total_tokens.load(Ordering::Relaxed)
149    }
150
151    /// Get success rate (0.0 - 1.0)
152    pub fn get_success_rate(&self) -> f64 {
153        let total = self.get_total_requests();
154        if total == 0 {
155            return 0.0;
156        }
157        self.get_successful_requests() as f64 / total as f64
158    }
159}
160
161/// Ollama configuration
162#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
163pub struct OllamaConfig {
164    /// Enable Ollama integration
165    pub enabled: bool,
166    /// Ollama host URL
167    pub host: String,
168    /// Ollama port
169    pub port: u16,
170    /// Model for embeddings
171    pub embedding_model: String,
172    /// Model for chat/generation
173    pub chat_model: String,
174    /// Timeout in seconds
175    pub timeout_seconds: u64,
176    /// Maximum retry attempts
177    pub max_retries: u32,
178    /// Fallback to hash-based IDs on error
179    pub fallback_to_hash: bool,
180    /// Maximum tokens to generate
181    pub max_tokens: Option<u32>,
182    /// Temperature for generation (0.0 - 1.0)
183    pub temperature: Option<f32>,
184    /// Enable model caching
185    pub enable_caching: bool,
186    /// How long to keep the model loaded in memory between requests (e.g. "1h", "30m", "0").
187    ///
188    /// Without this, Ollama may unload the model between requests, destroying the KV cache
189    /// and forcing full re-evaluation of long document contexts on every request.
190    /// Set to "1h" when processing multiple chunks from the same document.
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub keep_alive: Option<String>,
193    /// Default context window size for generation requests.
194    ///
195    /// Ollama silently truncates prompts exceeding this value (default is often 2048-8192).
196    /// For long-document processing, set this to at least:
197    /// `tokens(document) + tokens(max_chunk) + tokens(instructions) + 150 output tokens`
198    /// Use `None` to let Ollama use its model default.
199    #[serde(skip_serializing_if = "Option::is_none")]
200    pub num_ctx: Option<u32>,
201}
202
203impl Default for OllamaConfig {
204    fn default() -> Self {
205        Self {
206            enabled: false,
207            host: "http://localhost".to_string(),
208            port: 11434,
209            embedding_model: "nomic-embed-text".to_string(),
210            chat_model: "llama3.2:3b".to_string(),
211            timeout_seconds: 30,
212            max_retries: 3,
213            fallback_to_hash: true,
214            max_tokens: Some(2000),
215            temperature: Some(0.7),
216            enable_caching: true,
217            keep_alive: None,
218            num_ctx: None,
219        }
220    }
221}
222
223/// Ollama client for LLM inference
224#[derive(Clone)]
225pub struct OllamaClient {
226    config: OllamaConfig,
227    #[cfg(feature = "ureq")]
228    client: ureq::Agent,
229    /// Usage statistics
230    stats: OllamaUsageStats,
231    /// Response cache (prompt -> response)
232    #[cfg(feature = "dashmap")]
233    cache: Arc<dashmap::DashMap<String, String>>,
234}
235
236impl std::fmt::Debug for OllamaClient {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        f.debug_struct("OllamaClient")
239            .field("config", &self.config)
240            .field("stats", &self.stats)
241            .finish()
242    }
243}
244
245impl OllamaClient {
246    /// Create a new Ollama client
247    pub fn new(config: OllamaConfig) -> Self {
248        Self {
249            config: config.clone(),
250            #[cfg(feature = "ureq")]
251            client: ureq::AgentBuilder::new()
252                .timeout(std::time::Duration::from_secs(config.timeout_seconds))
253                .build(),
254            stats: OllamaUsageStats::new(),
255            #[cfg(feature = "dashmap")]
256            cache: Arc::new(dashmap::DashMap::new()),
257        }
258    }
259
260    /// Get usage statistics
261    pub fn get_stats(&self) -> &OllamaUsageStats {
262        &self.stats
263    }
264
265    /// Access the underlying Ollama configuration
266    pub fn config(&self) -> &OllamaConfig {
267        &self.config
268    }
269
270    /// Clear the cache
271    #[cfg(feature = "dashmap")]
272    pub fn clear_cache(&self) {
273        self.cache.clear();
274    }
275
276    /// Get cache size
277    #[cfg(feature = "dashmap")]
278    pub fn cache_size(&self) -> usize {
279        self.cache.len()
280    }
281
282    /// Generate text completion using Ollama API
283    #[cfg(feature = "ureq")]
284    pub async fn generate(&self, prompt: &str) -> Result<String> {
285        // Check cache first if enabled
286        #[cfg(feature = "dashmap")]
287        {
288            if self.config.enable_caching {
289                if let Some(cached_response) = self.cache.get(prompt) {
290                    #[cfg(feature = "tracing")]
291                    tracing::debug!("Cache hit for prompt (length: {})", prompt.len());
292                    return Ok(cached_response.clone());
293                }
294            }
295        }
296
297        // Use default parameters
298        let params = OllamaGenerationParams {
299            num_predict: self.config.max_tokens,
300            temperature: self.config.temperature,
301            ..Default::default()
302        };
303
304        self.generate_with_params(prompt, params).await
305    }
306
307    /// Generate text completion with custom parameters
308    #[cfg(feature = "ureq")]
309    pub async fn generate_with_params(
310        &self,
311        prompt: &str,
312        params: OllamaGenerationParams,
313    ) -> Result<String> {
314        let endpoint = format!("{}:{}/api/generate", self.config.host, self.config.port);
315
316        // Extract keep_alive before serializing params (it's a top-level field, not an option)
317        let keep_alive = params
318            .keep_alive
319            .clone()
320            .or_else(|| self.config.keep_alive.clone());
321
322        let mut request_body = serde_json::json!({
323            "model": self.config.chat_model,
324            "prompt": prompt,
325            "stream": false,
326        });
327
328        // keep_alive is a top-level field (controls model unloading between requests)
329        if let Some(ref ka) = keep_alive {
330            request_body["keep_alive"] = serde_json::Value::String(ka.clone());
331        }
332
333        // context is a top-level field: KV cache token state from a previous response.
334        // When set, the model continues from this state, skipping re-evaluation of prior tokens.
335        if let Some(ref ctx) = params.context {
336            request_body["context"] = serde_json::Value::Array(
337                ctx.iter()
338                    .map(|&t| serde_json::Value::Number(t.into()))
339                    .collect(),
340            );
341        }
342
343        // Build options object: serialized params + num_ctx
344        let mut options = serde_json::to_value(&params).map_err(|e| GraphRAGError::Generation {
345            message: format!("Failed to serialize generation params: {}", e),
346        })?;
347
348        // Add num_ctx to options (overrides config default if set in params)
349        let effective_num_ctx = params.num_ctx.or(self.config.num_ctx);
350        if let Some(num_ctx) = effective_num_ctx {
351            if let Some(obj) = options.as_object_mut() {
352                obj.insert(
353                    "num_ctx".to_string(),
354                    serde_json::Value::Number(num_ctx.into()),
355                );
356            }
357        }
358
359        if !options.as_object().map_or(true, |o| o.is_empty()) {
360            request_body["options"] = options;
361        }
362
363        // Make HTTP request with retry logic
364        let mut last_error = None;
365        for attempt in 1..=self.config.max_retries {
366            match self
367                .client
368                .post(&endpoint)
369                .set("Content-Type", "application/json")
370                .send_json(&request_body)
371            {
372                Ok(response) => {
373                    let json_response: serde_json::Value =
374                        response
375                            .into_json()
376                            .map_err(|e| GraphRAGError::Generation {
377                                message: format!("Failed to parse JSON response: {}", e),
378                            })?;
379
380                    // Extract response text
381                    if let Some(response_text) = json_response["response"].as_str() {
382                        let response_string = response_text.to_string();
383
384                        // Estimate tokens (rough: ~4 chars per token)
385                        let estimated_tokens = (prompt.len() + response_string.len()) / 4;
386                        self.stats.record_success(estimated_tokens as u64);
387
388                        // Cache the response if enabled
389                        #[cfg(feature = "dashmap")]
390                        {
391                            if self.config.enable_caching {
392                                self.cache
393                                    .insert(prompt.to_string(), response_string.clone());
394
395                                #[cfg(feature = "tracing")]
396                                tracing::debug!(
397                                    "Cached response for prompt (length: {})",
398                                    prompt.len()
399                                );
400                            }
401                        }
402
403                        return Ok(response_string);
404                    } else {
405                        self.stats.record_failure();
406                        return Err(GraphRAGError::Generation {
407                            message: format!("Invalid response format: {:?}", json_response),
408                        });
409                    }
410                },
411                Err(e) => {
412                    #[cfg(feature = "tracing")]
413                    tracing::warn!("Ollama API request failed (attempt {}): {}", attempt, e);
414                    last_error = Some(e);
415
416                    if attempt < self.config.max_retries {
417                        // Wait before retry (exponential backoff)
418                        tokio::time::sleep(std::time::Duration::from_millis(100 * attempt as u64))
419                            .await;
420                    }
421                },
422            }
423        }
424
425        self.stats.record_failure();
426        Err(GraphRAGError::Generation {
427            message: format!(
428                "Ollama API failed after {} retries: {:?}",
429                self.config.max_retries, last_error
430            ),
431        })
432    }
433
434    /// Generate text and return the full response including KV cache context and token stats.
435    ///
436    /// Use this for the two-step contextual enrichment pattern:
437    ///
438    /// ```no_run
439    /// # use graphrag_core::ollama::{OllamaClient, OllamaConfig, OllamaGenerationParams};
440    /// # async fn example() -> graphrag_core::Result<()> {
441    /// let client = OllamaClient::new(OllamaConfig::default());
442    ///
443    /// // Step 1: Prime — load the document into Ollama's KV cache
444    /// let prime_params = OllamaGenerationParams {
445    ///     num_predict: Some(1), // generate minimal output; we just want the context
446    ///     keep_alive: Some("1h".to_string()),
447    ///     num_ctx: Some(32768),
448    ///     ..Default::default()
449    /// };
450    /// let prime = client.generate_with_full_response("<document>..full doc..</document>", prime_params).await?;
451    /// println!("Prompt tokens evaluated: {}", prime.prompt_eval_count); // ~doc_tokens
452    ///
453    /// // Step 2: Per chunk — only the chunk tokens are evaluated
454    /// for chunk in chunks {
455    ///     let params = OllamaGenerationParams {
456    ///         num_predict: Some(80),
457    ///         context: Some(prime.context.clone()),  // ← KV cache reuse!
458    ///         keep_alive: Some("1h".to_string()),
459    ///         ..Default::default()
460    ///     };
461    ///     let resp = client.generate_with_full_response(&chunk, params).await?;
462    ///     println!("Chunk tokens evaluated: {}", resp.prompt_eval_count); // ~chunk_tokens, not doc_tokens!
463    /// }
464    /// # Ok(())
465    /// # }
466    /// ```
467    #[cfg(feature = "ureq")]
468    pub async fn generate_with_full_response(
469        &self,
470        prompt: &str,
471        params: OllamaGenerationParams,
472    ) -> Result<OllamaGenerateResponse> {
473        let endpoint = format!("{}:{}/api/generate", self.config.host, self.config.port);
474
475        let keep_alive = params
476            .keep_alive
477            .clone()
478            .or_else(|| self.config.keep_alive.clone());
479
480        let mut request_body = serde_json::json!({
481            "model": self.config.chat_model,
482            "prompt": prompt,
483            "stream": false,
484        });
485
486        if let Some(ref ka) = keep_alive {
487            request_body["keep_alive"] = serde_json::Value::String(ka.clone());
488        }
489
490        if let Some(ref ctx) = params.context {
491            request_body["context"] = serde_json::Value::Array(
492                ctx.iter()
493                    .map(|&t| serde_json::Value::Number(t.into()))
494                    .collect(),
495            );
496        }
497
498        let mut options = serde_json::to_value(&params).map_err(|e| GraphRAGError::Generation {
499            message: format!("Failed to serialize generation params: {}", e),
500        })?;
501
502        let effective_num_ctx = params.num_ctx.or(self.config.num_ctx);
503        if let Some(num_ctx) = effective_num_ctx {
504            if let Some(obj) = options.as_object_mut() {
505                obj.insert(
506                    "num_ctx".to_string(),
507                    serde_json::Value::Number(num_ctx.into()),
508                );
509            }
510        }
511
512        if !options.as_object().map_or(true, |o| o.is_empty()) {
513            request_body["options"] = options;
514        }
515
516        let mut last_error = None;
517        for attempt in 1..=self.config.max_retries {
518            match self
519                .client
520                .post(&endpoint)
521                .set("Content-Type", "application/json")
522                .send_json(&request_body)
523            {
524                Ok(response) => {
525                    let json_response: serde_json::Value =
526                        response
527                            .into_json()
528                            .map_err(|e| GraphRAGError::Generation {
529                                message: format!("Failed to parse JSON response: {}", e),
530                            })?;
531
532                    let text = json_response["response"]
533                        .as_str()
534                        .ok_or_else(|| GraphRAGError::Generation {
535                            message: format!("Invalid response format: {:?}", json_response),
536                        })?
537                        .to_string();
538
539                    let context: Vec<i64> = json_response["context"]
540                        .as_array()
541                        .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
542                        .unwrap_or_default();
543
544                    let prompt_eval_count =
545                        json_response["prompt_eval_count"].as_u64().unwrap_or(0);
546                    let eval_count = json_response["eval_count"].as_u64().unwrap_or(0);
547
548                    let estimated_tokens = (prompt.len() + text.len()) / 4;
549                    self.stats.record_success(estimated_tokens as u64);
550
551                    return Ok(OllamaGenerateResponse {
552                        text,
553                        context,
554                        prompt_eval_count,
555                        eval_count,
556                    });
557                },
558                Err(e) => {
559                    last_error = Some(e);
560                    if attempt < self.config.max_retries {
561                        tokio::time::sleep(std::time::Duration::from_millis(100 * attempt as u64))
562                            .await;
563                    }
564                },
565            }
566        }
567
568        self.stats.record_failure();
569        Err(GraphRAGError::Generation {
570            message: format!(
571                "Ollama API failed after {} retries: {:?}",
572                self.config.max_retries, last_error
573            ),
574        })
575    }
576
577    /// Generate streaming completion
578    ///
579    /// Returns a channel receiver that yields tokens as they are generated.
580    /// This enables real-time display of generation progress.
581    ///
582    /// # Example
583    /// ```no_run
584    /// use graphrag_core::ollama::{OllamaClient, OllamaConfig};
585    ///
586    /// # async fn example() -> graphrag_core::Result<()> {
587    /// let client = OllamaClient::new(OllamaConfig::default());
588    /// let mut rx = client.generate_streaming("Write a story").await?;
589    ///
590    /// while let Some(token) = rx.recv().await {
591    ///     print!("{}", token);
592    /// }
593    /// # Ok(())
594    /// # }
595    /// ```
596    #[cfg(all(feature = "ureq", feature = "tokio"))]
597    pub async fn generate_streaming(
598        &self,
599        prompt: &str,
600    ) -> Result<tokio::sync::mpsc::Receiver<String>> {
601        let endpoint = format!("{}:{}/api/generate", self.config.host, self.config.port);
602
603        let params = OllamaGenerationParams {
604            num_predict: self.config.max_tokens,
605            temperature: self.config.temperature,
606            ..Default::default()
607        };
608
609        let mut request_body = serde_json::json!({
610            "model": self.config.chat_model,
611            "prompt": prompt,
612            "stream": true,  // Enable streaming
613        });
614
615        // Add custom parameters
616        let options = serde_json::to_value(&params).map_err(|e| GraphRAGError::Generation {
617            message: format!("Failed to serialize generation params: {}", e),
618        })?;
619
620        if !options.as_object().unwrap().is_empty() {
621            request_body["options"] = options;
622        }
623
624        // Create channel for streaming tokens
625        let (tx, rx) = tokio::sync::mpsc::channel(100);
626
627        // Clone data needed for async task
628        let client = self.client.clone();
629        let stats = self.stats.clone();
630        let prompt_len = prompt.len();
631
632        // Spawn background task to read streaming response
633        tokio::spawn(async move {
634            match client
635                .post(&endpoint)
636                .set("Content-Type", "application/json")
637                .send_json(&request_body)
638            {
639                Ok(response) => {
640                    let reader = std::io::BufReader::new(response.into_reader());
641                    use std::io::BufRead;
642
643                    let mut total_response = String::new();
644
645                    for line in reader.lines() {
646                        match line {
647                            Ok(line_str) => {
648                                if line_str.is_empty() {
649                                    continue;
650                                }
651
652                                // Parse JSON response for this chunk
653                                if let Ok(json) =
654                                    serde_json::from_str::<serde_json::Value>(&line_str)
655                                {
656                                    if let Some(token) = json["response"].as_str() {
657                                        total_response.push_str(token);
658
659                                        // Send token through channel
660                                        if tx.send(token.to_string()).await.is_err() {
661                                            // Receiver dropped, stop streaming
662                                            break;
663                                        }
664                                    }
665
666                                    // Check if done
667                                    if json["done"].as_bool() == Some(true) {
668                                        // Record success
669                                        let estimated_tokens =
670                                            (prompt_len + total_response.len()) / 4;
671                                        stats.record_success(estimated_tokens as u64);
672                                        break;
673                                    }
674                                }
675                            },
676                            Err(e) => {
677                                #[cfg(feature = "tracing")]
678                                tracing::error!("Error reading streaming response: {}", e);
679                                stats.record_failure();
680                                break;
681                            },
682                        }
683                    }
684                },
685                Err(e) => {
686                    #[cfg(feature = "tracing")]
687                    tracing::error!("Failed to initiate streaming request: {}", e);
688                    stats.record_failure();
689                },
690            }
691        });
692
693        Ok(rx)
694    }
695
696    /// Generate text completion (sync fallback when ureq feature is disabled)
697    #[cfg(not(feature = "ureq"))]
698    pub async fn generate(&self, _prompt: &str) -> Result<String> {
699        Err(GraphRAGError::Generation {
700            message: "ureq feature required for Ollama integration".to_string(),
701        })
702    }
703
704    /// Generate with custom parameters (fallback)
705    #[cfg(not(feature = "ureq"))]
706    pub async fn generate_with_params(
707        &self,
708        _prompt: &str,
709        _params: OllamaGenerationParams,
710    ) -> Result<String> {
711        Err(GraphRAGError::Generation {
712            message: "ureq feature required for Ollama integration".to_string(),
713        })
714    }
715}