// graphrag_core/ollama/mod.rs
use crate::core::{GraphRAGError, Result};
/// Configuration for talking to a local Ollama server.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OllamaConfig {
    /// Master switch for Ollama integration.
    pub enabled: bool,
    /// Base URL of the server, scheme included (e.g. "http://localhost").
    pub host: String,
    /// TCP port the server listens on (Ollama's stock port is 11434).
    pub port: u16,
    /// Model name used for embedding requests.
    pub embedding_model: String,
    /// Model name used for chat/generation requests (`/api/generate`).
    pub chat_model: String,
    /// Per-request HTTP timeout, in seconds.
    pub timeout_seconds: u64,
    /// Total number of request attempts before giving up.
    pub max_retries: u32,
    /// Presumably: fall back to hash-based embeddings when Ollama is
    /// unavailable — consumed elsewhere; confirm against the caller.
    pub fallback_to_hash: bool,
    /// Optional cap on generated tokens; sent as Ollama's `num_predict`.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature forwarded in the request options.
    pub temperature: Option<f32>,
}
31
32impl Default for OllamaConfig {
33 fn default() -> Self {
34 Self {
35 enabled: false,
36 host: "http://localhost".to_string(),
37 port: 11434,
38 embedding_model: "nomic-embed-text".to_string(),
39 chat_model: "llama3.2:3b".to_string(),
40 timeout_seconds: 30,
41 max_retries: 3,
42 fallback_to_hash: true,
43 max_tokens: Some(2000),
44 temperature: Some(0.7),
45 }
46 }
47}
48
/// Thin HTTP client for the Ollama REST API.
#[derive(Debug, Clone)]
pub struct OllamaClient {
    /// Connection and generation settings for this client.
    config: OllamaConfig,
    /// Reusable HTTP agent, present only when built with the `ureq` feature.
    #[cfg(feature = "ureq")]
    client: ureq::Agent,
}
56
57impl OllamaClient {
58 pub fn new(config: OllamaConfig) -> Self {
60 Self {
61 config: config.clone(),
62 #[cfg(feature = "ureq")]
63 client: ureq::AgentBuilder::new()
64 .timeout(std::time::Duration::from_secs(config.timeout_seconds))
65 .build(),
66 }
67 }
68
69 #[cfg(feature = "ureq")]
71 pub async fn generate(&self, prompt: &str) -> Result<String> {
72 let endpoint = format!("{}:{}/api/generate", self.config.host, self.config.port);
73
74 let mut request_body = serde_json::json!({
75 "model": self.config.chat_model,
76 "prompt": prompt,
77 "stream": false,
78 });
79
80 if let Some(max_tokens) = self.config.max_tokens {
82 request_body["options"] = serde_json::json!({
83 "num_predict": max_tokens,
84 });
85 }
86
87 if let Some(temperature) = self.config.temperature {
88 if request_body.get("options").is_none() {
89 request_body["options"] = serde_json::json!({});
90 }
91 request_body["options"]["temperature"] = serde_json::json!(temperature);
92 }
93
94 let mut last_error = None;
96 for attempt in 1..=self.config.max_retries {
97 match self.client
98 .post(&endpoint)
99 .set("Content-Type", "application/json")
100 .send_json(&request_body)
101 {
102 Ok(response) => {
103 let json_response: serde_json::Value = response
104 .into_json()
105 .map_err(|e| GraphRAGError::Generation {
106 message: format!("Failed to parse JSON response: {}", e),
107 })?;
108
109 if let Some(response_text) = json_response["response"].as_str() {
111 return Ok(response_text.to_string());
112 } else {
113 return Err(GraphRAGError::Generation {
114 message: format!("Invalid response format: {:?}", json_response),
115 });
116 }
117 }
118 Err(e) => {
119 tracing::warn!("Ollama API request failed (attempt {}): {}", attempt, e);
120 last_error = Some(e);
121
122 if attempt < self.config.max_retries {
123 tokio::time::sleep(std::time::Duration::from_millis(100 * attempt as u64)).await;
125 }
126 }
127 }
128 }
129
130 Err(GraphRAGError::Generation {
131 message: format!("Ollama API failed after {} retries: {:?}",
132 self.config.max_retries, last_error),
133 })
134 }
135
136 #[cfg(not(feature = "ureq"))]
138 pub async fn generate(&self, _prompt: &str) -> Result<String> {
139 Err(GraphRAGError::Generation {
140 message: "ureq feature required for Ollama integration".to_string(),
141 })
142 }
143}