1use crate::core::{GraphRAGError, Result};
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
/// Per-request generation options forwarded to Ollama's `/api/generate`
/// endpoint. Fields left as `None` are omitted from the serialized
/// `options` object so the server falls back to its own defaults.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OllamaGenerationParams {
    /// Maximum number of tokens to generate (Ollama's `num_predict`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<u32>,
    /// Sampling temperature; higher values yield more random output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability-mass cutoff.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Stop sequences that terminate generation when emitted.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Penalty applied to repeated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repeat_penalty: Option<f32>,
    /// Context window size. When `None`, `OllamaConfig::num_ctx` is used
    /// as a fallback (see `OllamaClient::generate_with_params`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_ctx: Option<u32>,
    /// How long the model stays loaded after the request (e.g. "5m").
    /// `#[serde(skip)]` because the client sends this as a top-level
    /// request field, not inside the `options` object.
    #[serde(skip)]
    pub keep_alive: Option<String>,

    /// Token context from a previous `/api/generate` response, used to
    /// continue a conversation. Also `skip`ped here: the client injects
    /// it as a top-level request field.
    #[serde(skip)]
    pub context: Option<Vec<i64>>,
}
64
65impl Default for OllamaGenerationParams {
66 fn default() -> Self {
67 Self {
68 num_predict: Some(2000),
69 temperature: Some(0.7),
70 top_p: Some(0.9),
71 top_k: Some(40),
72 stop: None,
73 repeat_penalty: Some(1.1),
74 num_ctx: None,
75 keep_alive: None,
76 context: None,
77 }
78 }
79}
80
/// Full result of a non-streaming `/api/generate` call, including the
/// metadata needed to continue a conversation or account for token usage.
#[derive(Debug, Clone)]
pub struct OllamaGenerateResponse {
    /// The generated completion text.
    pub text: String,
    /// Opaque token context; pass it back via
    /// `OllamaGenerationParams::context` to continue from this response.
    /// Empty when the server omitted it.
    pub context: Vec<i64>,
    /// Tokens consumed by the prompt, as reported by Ollama (0 if absent).
    pub prompt_eval_count: u64,
    /// Tokens generated for the response, as reported by Ollama (0 if absent).
    pub eval_count: u64,
}
98
/// Thread-safe request/usage counters.
///
/// Each counter sits behind an `Arc<AtomicU64>`, so `Clone` hands out
/// another view of the *same* counters rather than a fresh set — a
/// cloned client keeps reporting into the shared totals.
#[derive(Debug, Clone, Default)]
pub struct OllamaUsageStats {
    /// Every request attempted, whether it succeeded or not.
    pub total_requests: Arc<AtomicU64>,
    /// Requests that produced a usable response.
    pub successful_requests: Arc<AtomicU64>,
    /// Requests that ultimately failed.
    pub failed_requests: Arc<AtomicU64>,
    /// Running (estimated) token total across successful requests.
    pub total_tokens: Arc<AtomicU64>,
}

impl OllamaUsageStats {
    /// All counters start at zero.
    pub fn new() -> Self {
        Self::default()
    }

    // Relaxed ordering is fine throughout: each counter is independent
    // and nothing synchronizes on their relative values.
    fn bump(counter: &AtomicU64, amount: u64) {
        counter.fetch_add(amount, Ordering::Relaxed);
    }

    fn read(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::Relaxed)
    }

    /// Counts one successful request and its (estimated) token usage.
    pub fn record_success(&self, tokens: u64) {
        Self::bump(&self.total_requests, 1);
        Self::bump(&self.successful_requests, 1);
        Self::bump(&self.total_tokens, tokens);
    }

    /// Counts one failed request.
    pub fn record_failure(&self) {
        Self::bump(&self.total_requests, 1);
        Self::bump(&self.failed_requests, 1);
    }

    /// Total requests attempted so far.
    pub fn get_total_requests(&self) -> u64 {
        Self::read(&self.total_requests)
    }

    /// Requests that succeeded.
    pub fn get_successful_requests(&self) -> u64 {
        Self::read(&self.successful_requests)
    }

    /// Requests that failed.
    pub fn get_failed_requests(&self) -> u64 {
        Self::read(&self.failed_requests)
    }

    /// Estimated tokens consumed by successful requests.
    pub fn get_total_tokens(&self) -> u64 {
        Self::read(&self.total_tokens)
    }

    /// Successful / total as a fraction, or 0.0 before any request.
    pub fn get_success_rate(&self) -> f64 {
        match Self::read(&self.total_requests) {
            0 => 0.0,
            total => Self::read(&self.successful_requests) as f64 / total as f64,
        }
    }
}
160
/// Connection and behavior settings for the Ollama integration.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OllamaConfig {
    /// Master switch; when false the integration should not be used.
    pub enabled: bool,
    /// Base URL of the Ollama server, including scheme
    /// (e.g. "http://localhost").
    pub host: String,
    /// TCP port the server listens on (Ollama's default is 11434).
    pub port: u16,
    /// Model used for embedding requests.
    pub embedding_model: String,
    /// Model used for chat/generation requests.
    pub chat_model: String,
    /// Per-request HTTP timeout in seconds (applied to the ureq agent).
    pub timeout_seconds: u64,
    /// Number of attempts made before a request is reported as failed.
    pub max_retries: u32,
    /// Whether callers may fall back to hash-based embeddings when Ollama
    /// is unavailable. NOTE(review): not read by this client itself —
    /// presumably enforced at the call sites; confirm there.
    pub fallback_to_hash: bool,
    /// Default generation length limit (maps to `num_predict`).
    pub max_tokens: Option<u32>,
    /// Default sampling temperature.
    pub temperature: Option<f32>,
    /// Enables the in-memory prompt -> response cache
    /// (requires the `dashmap` feature).
    pub enable_caching: bool,
    /// Default keep-alive for loaded models; per-request
    /// `OllamaGenerationParams::keep_alive` takes priority.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<String>,
    /// Default context window; per-request
    /// `OllamaGenerationParams::num_ctx` takes priority.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_ctx: Option<u32>,
}
202
203impl Default for OllamaConfig {
204 fn default() -> Self {
205 Self {
206 enabled: false,
207 host: "http://localhost".to_string(),
208 port: 11434,
209 embedding_model: "nomic-embed-text".to_string(),
210 chat_model: "llama3.2:3b".to_string(),
211 timeout_seconds: 30,
212 max_retries: 3,
213 fallback_to_hash: true,
214 max_tokens: Some(2000),
215 temperature: Some(0.7),
216 enable_caching: true,
217 keep_alive: None,
218 num_ctx: None,
219 }
220 }
221}
222
/// HTTP client for an Ollama server with optional response caching.
///
/// `Clone` is shallow where it matters: the stats counters and the cache
/// sit behind `Arc`s, so clones observe shared state.
#[derive(Clone)]
pub struct OllamaClient {
    config: OllamaConfig,
    /// Blocking HTTP agent, built with the configured request timeout.
    #[cfg(feature = "ureq")]
    client: ureq::Agent,
    /// Shared usage counters (cheap to clone; see `OllamaUsageStats`).
    stats: OllamaUsageStats,
    /// prompt -> response cache, consulted when `config.enable_caching`
    /// is set.
    #[cfg(feature = "dashmap")]
    cache: Arc<dashmap::DashMap<String, String>>,
}
235
236impl std::fmt::Debug for OllamaClient {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 f.debug_struct("OllamaClient")
239 .field("config", &self.config)
240 .field("stats", &self.stats)
241 .finish()
242 }
243}
244
245impl OllamaClient {
246 pub fn new(config: OllamaConfig) -> Self {
248 Self {
249 config: config.clone(),
250 #[cfg(feature = "ureq")]
251 client: ureq::AgentBuilder::new()
252 .timeout(std::time::Duration::from_secs(config.timeout_seconds))
253 .build(),
254 stats: OllamaUsageStats::new(),
255 #[cfg(feature = "dashmap")]
256 cache: Arc::new(dashmap::DashMap::new()),
257 }
258 }
259
260 pub fn get_stats(&self) -> &OllamaUsageStats {
262 &self.stats
263 }
264
265 pub fn config(&self) -> &OllamaConfig {
267 &self.config
268 }
269
270 #[cfg(feature = "dashmap")]
272 pub fn clear_cache(&self) {
273 self.cache.clear();
274 }
275
276 #[cfg(feature = "dashmap")]
278 pub fn cache_size(&self) -> usize {
279 self.cache.len()
280 }
281
282 #[cfg(feature = "ureq")]
284 pub async fn generate(&self, prompt: &str) -> Result<String> {
285 #[cfg(feature = "dashmap")]
287 {
288 if self.config.enable_caching {
289 if let Some(cached_response) = self.cache.get(prompt) {
290 #[cfg(feature = "tracing")]
291 tracing::debug!("Cache hit for prompt (length: {})", prompt.len());
292 return Ok(cached_response.clone());
293 }
294 }
295 }
296
297 let params = OllamaGenerationParams {
299 num_predict: self.config.max_tokens,
300 temperature: self.config.temperature,
301 ..Default::default()
302 };
303
304 self.generate_with_params(prompt, params).await
305 }
306
307 #[cfg(feature = "ureq")]
309 pub async fn generate_with_params(
310 &self,
311 prompt: &str,
312 params: OllamaGenerationParams,
313 ) -> Result<String> {
314 let endpoint = format!("{}:{}/api/generate", self.config.host, self.config.port);
315
316 let keep_alive = params
318 .keep_alive
319 .clone()
320 .or_else(|| self.config.keep_alive.clone());
321
322 let mut request_body = serde_json::json!({
323 "model": self.config.chat_model,
324 "prompt": prompt,
325 "stream": false,
326 });
327
328 if let Some(ref ka) = keep_alive {
330 request_body["keep_alive"] = serde_json::Value::String(ka.clone());
331 }
332
333 if let Some(ref ctx) = params.context {
336 request_body["context"] = serde_json::Value::Array(
337 ctx.iter()
338 .map(|&t| serde_json::Value::Number(t.into()))
339 .collect(),
340 );
341 }
342
343 let mut options = serde_json::to_value(¶ms).map_err(|e| GraphRAGError::Generation {
345 message: format!("Failed to serialize generation params: {}", e),
346 })?;
347
348 let effective_num_ctx = params.num_ctx.or(self.config.num_ctx);
350 if let Some(num_ctx) = effective_num_ctx {
351 if let Some(obj) = options.as_object_mut() {
352 obj.insert(
353 "num_ctx".to_string(),
354 serde_json::Value::Number(num_ctx.into()),
355 );
356 }
357 }
358
359 if !options.as_object().map_or(true, |o| o.is_empty()) {
360 request_body["options"] = options;
361 }
362
363 let mut last_error = None;
365 for attempt in 1..=self.config.max_retries {
366 match self
367 .client
368 .post(&endpoint)
369 .set("Content-Type", "application/json")
370 .send_json(&request_body)
371 {
372 Ok(response) => {
373 let json_response: serde_json::Value =
374 response
375 .into_json()
376 .map_err(|e| GraphRAGError::Generation {
377 message: format!("Failed to parse JSON response: {}", e),
378 })?;
379
380 if let Some(response_text) = json_response["response"].as_str() {
382 let response_string = response_text.to_string();
383
384 let estimated_tokens = (prompt.len() + response_string.len()) / 4;
386 self.stats.record_success(estimated_tokens as u64);
387
388 #[cfg(feature = "dashmap")]
390 {
391 if self.config.enable_caching {
392 self.cache
393 .insert(prompt.to_string(), response_string.clone());
394
395 #[cfg(feature = "tracing")]
396 tracing::debug!(
397 "Cached response for prompt (length: {})",
398 prompt.len()
399 );
400 }
401 }
402
403 return Ok(response_string);
404 } else {
405 self.stats.record_failure();
406 return Err(GraphRAGError::Generation {
407 message: format!("Invalid response format: {:?}", json_response),
408 });
409 }
410 },
411 Err(e) => {
412 #[cfg(feature = "tracing")]
413 tracing::warn!("Ollama API request failed (attempt {}): {}", attempt, e);
414 last_error = Some(e);
415
416 if attempt < self.config.max_retries {
417 tokio::time::sleep(std::time::Duration::from_millis(100 * attempt as u64))
419 .await;
420 }
421 },
422 }
423 }
424
425 self.stats.record_failure();
426 Err(GraphRAGError::Generation {
427 message: format!(
428 "Ollama API failed after {} retries: {:?}",
429 self.config.max_retries, last_error
430 ),
431 })
432 }
433
434 #[cfg(feature = "ureq")]
468 pub async fn generate_with_full_response(
469 &self,
470 prompt: &str,
471 params: OllamaGenerationParams,
472 ) -> Result<OllamaGenerateResponse> {
473 let endpoint = format!("{}:{}/api/generate", self.config.host, self.config.port);
474
475 let keep_alive = params
476 .keep_alive
477 .clone()
478 .or_else(|| self.config.keep_alive.clone());
479
480 let mut request_body = serde_json::json!({
481 "model": self.config.chat_model,
482 "prompt": prompt,
483 "stream": false,
484 });
485
486 if let Some(ref ka) = keep_alive {
487 request_body["keep_alive"] = serde_json::Value::String(ka.clone());
488 }
489
490 if let Some(ref ctx) = params.context {
491 request_body["context"] = serde_json::Value::Array(
492 ctx.iter()
493 .map(|&t| serde_json::Value::Number(t.into()))
494 .collect(),
495 );
496 }
497
498 let mut options = serde_json::to_value(¶ms).map_err(|e| GraphRAGError::Generation {
499 message: format!("Failed to serialize generation params: {}", e),
500 })?;
501
502 let effective_num_ctx = params.num_ctx.or(self.config.num_ctx);
503 if let Some(num_ctx) = effective_num_ctx {
504 if let Some(obj) = options.as_object_mut() {
505 obj.insert(
506 "num_ctx".to_string(),
507 serde_json::Value::Number(num_ctx.into()),
508 );
509 }
510 }
511
512 if !options.as_object().map_or(true, |o| o.is_empty()) {
513 request_body["options"] = options;
514 }
515
516 let mut last_error = None;
517 for attempt in 1..=self.config.max_retries {
518 match self
519 .client
520 .post(&endpoint)
521 .set("Content-Type", "application/json")
522 .send_json(&request_body)
523 {
524 Ok(response) => {
525 let json_response: serde_json::Value =
526 response
527 .into_json()
528 .map_err(|e| GraphRAGError::Generation {
529 message: format!("Failed to parse JSON response: {}", e),
530 })?;
531
532 let text = json_response["response"]
533 .as_str()
534 .ok_or_else(|| GraphRAGError::Generation {
535 message: format!("Invalid response format: {:?}", json_response),
536 })?
537 .to_string();
538
539 let context: Vec<i64> = json_response["context"]
540 .as_array()
541 .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
542 .unwrap_or_default();
543
544 let prompt_eval_count =
545 json_response["prompt_eval_count"].as_u64().unwrap_or(0);
546 let eval_count = json_response["eval_count"].as_u64().unwrap_or(0);
547
548 let estimated_tokens = (prompt.len() + text.len()) / 4;
549 self.stats.record_success(estimated_tokens as u64);
550
551 return Ok(OllamaGenerateResponse {
552 text,
553 context,
554 prompt_eval_count,
555 eval_count,
556 });
557 },
558 Err(e) => {
559 last_error = Some(e);
560 if attempt < self.config.max_retries {
561 tokio::time::sleep(std::time::Duration::from_millis(100 * attempt as u64))
562 .await;
563 }
564 },
565 }
566 }
567
568 self.stats.record_failure();
569 Err(GraphRAGError::Generation {
570 message: format!(
571 "Ollama API failed after {} retries: {:?}",
572 self.config.max_retries, last_error
573 ),
574 })
575 }
576
577 #[cfg(all(feature = "ureq", feature = "tokio"))]
597 pub async fn generate_streaming(
598 &self,
599 prompt: &str,
600 ) -> Result<tokio::sync::mpsc::Receiver<String>> {
601 let endpoint = format!("{}:{}/api/generate", self.config.host, self.config.port);
602
603 let params = OllamaGenerationParams {
604 num_predict: self.config.max_tokens,
605 temperature: self.config.temperature,
606 ..Default::default()
607 };
608
609 let mut request_body = serde_json::json!({
610 "model": self.config.chat_model,
611 "prompt": prompt,
612 "stream": true, });
614
615 let options = serde_json::to_value(¶ms).map_err(|e| GraphRAGError::Generation {
617 message: format!("Failed to serialize generation params: {}", e),
618 })?;
619
620 if !options.as_object().unwrap().is_empty() {
621 request_body["options"] = options;
622 }
623
624 let (tx, rx) = tokio::sync::mpsc::channel(100);
626
627 let client = self.client.clone();
629 let stats = self.stats.clone();
630 let prompt_len = prompt.len();
631
632 tokio::spawn(async move {
634 match client
635 .post(&endpoint)
636 .set("Content-Type", "application/json")
637 .send_json(&request_body)
638 {
639 Ok(response) => {
640 let reader = std::io::BufReader::new(response.into_reader());
641 use std::io::BufRead;
642
643 let mut total_response = String::new();
644
645 for line in reader.lines() {
646 match line {
647 Ok(line_str) => {
648 if line_str.is_empty() {
649 continue;
650 }
651
652 if let Ok(json) =
654 serde_json::from_str::<serde_json::Value>(&line_str)
655 {
656 if let Some(token) = json["response"].as_str() {
657 total_response.push_str(token);
658
659 if tx.send(token.to_string()).await.is_err() {
661 break;
663 }
664 }
665
666 if json["done"].as_bool() == Some(true) {
668 let estimated_tokens =
670 (prompt_len + total_response.len()) / 4;
671 stats.record_success(estimated_tokens as u64);
672 break;
673 }
674 }
675 },
676 Err(e) => {
677 #[cfg(feature = "tracing")]
678 tracing::error!("Error reading streaming response: {}", e);
679 stats.record_failure();
680 break;
681 },
682 }
683 }
684 },
685 Err(e) => {
686 #[cfg(feature = "tracing")]
687 tracing::error!("Failed to initiate streaming request: {}", e);
688 stats.record_failure();
689 },
690 }
691 });
692
693 Ok(rx)
694 }
695
696 #[cfg(not(feature = "ureq"))]
698 pub async fn generate(&self, _prompt: &str) -> Result<String> {
699 Err(GraphRAGError::Generation {
700 message: "ureq feature required for Ollama integration".to_string(),
701 })
702 }
703
704 #[cfg(not(feature = "ureq"))]
706 pub async fn generate_with_params(
707 &self,
708 _prompt: &str,
709 _params: OllamaGenerationParams,
710 ) -> Result<String> {
711 Err(GraphRAGError::Generation {
712 message: "ureq feature required for Ollama integration".to_string(),
713 })
714 }
715}