1#![allow(dead_code)]
17
18use super::{ChatProvider, ProviderType};
19use crate::models::{ChatMessage, ChatRequest, ChatSession};
20use anyhow::Result;
21use serde::{Deserialize, Serialize};
22use std::path::PathBuf;
23
24pub struct OpenAICompatProvider {
26 provider_type: ProviderType,
28 name: String,
30 endpoint: String,
32 api_key: Option<String>,
34 model: Option<String>,
36 available: bool,
38 data_path: Option<PathBuf>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct OpenAIChatMessage {
45 pub role: String,
46 pub content: String,
47}
48
49#[derive(Debug, Serialize)]
51pub struct OpenAIChatRequest {
52 pub model: String,
53 pub messages: Vec<OpenAIChatMessage>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub temperature: Option<f32>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub max_tokens: Option<u32>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub stream: Option<bool>,
60}
61
62#[derive(Debug, Deserialize)]
64pub struct OpenAIChatResponse {
65 pub id: String,
66 pub choices: Vec<OpenAIChatChoice>,
67 #[allow(dead_code)]
68 pub model: String,
69}
70
71#[derive(Debug, Deserialize)]
73pub struct OpenAIChatChoice {
74 pub message: OpenAIChatMessage,
75 #[allow(dead_code)]
76 pub finish_reason: Option<String>,
77}
78
79impl OpenAICompatProvider {
80 pub fn new(
82 provider_type: ProviderType,
83 name: impl Into<String>,
84 endpoint: impl Into<String>,
85 ) -> Self {
86 let endpoint = endpoint.into();
87 Self {
88 provider_type,
89 name: name.into(),
90 endpoint: endpoint.clone(),
91 api_key: None,
92 model: None,
93 available: Self::check_availability(&endpoint),
94 data_path: None,
95 }
96 }
97
98 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
100 self.api_key = Some(api_key.into());
101 self
102 }
103
104 pub fn with_model(mut self, model: impl Into<String>) -> Self {
106 self.model = Some(model.into());
107 self
108 }
109
110 pub fn with_data_path(mut self, path: PathBuf) -> Self {
112 self.data_path = Some(path);
113 self
114 }
115
116 fn check_availability(endpoint: &str) -> bool {
118 !endpoint.is_empty()
120 }
121
122 pub fn session_to_messages(session: &ChatSession) -> Vec<OpenAIChatMessage> {
124 let mut messages = Vec::new();
125
126 for request in &session.requests {
127 if let Some(msg) = &request.message {
129 if let Some(text) = &msg.text {
130 messages.push(OpenAIChatMessage {
131 role: "user".to_string(),
132 content: text.clone(),
133 });
134 }
135 }
136
137 if let Some(response) = &request.response {
139 if let Some(text) = extract_response_text(response) {
140 messages.push(OpenAIChatMessage {
141 role: "assistant".to_string(),
142 content: text,
143 });
144 }
145 }
146 }
147
148 messages
149 }
150
151 pub fn messages_to_session(
153 messages: Vec<OpenAIChatMessage>,
154 model: &str,
155 provider_name: &str,
156 ) -> ChatSession {
157 let now = chrono::Utc::now().timestamp_millis();
158 let session_id = uuid::Uuid::new_v4().to_string();
159
160 let mut requests = Vec::new();
161 let mut user_msg: Option<String> = None;
162
163 for msg in messages {
164 match msg.role.as_str() {
165 "user" => {
166 user_msg = Some(msg.content);
167 }
168 "assistant" => {
169 if let Some(user_text) = user_msg.take() {
170 requests.push(ChatRequest {
171 timestamp: Some(now),
172 message: Some(ChatMessage {
173 text: Some(user_text),
174 parts: None,
175 }),
176 response: Some(serde_json::json!({
177 "value": [{"value": msg.content}]
178 })),
179 variable_data: None,
180 request_id: Some(uuid::Uuid::new_v4().to_string()),
181 response_id: Some(uuid::Uuid::new_v4().to_string()),
182 model_id: Some(model.to_string()),
183 agent: None,
184 result: None,
185 followups: None,
186 is_canceled: Some(false),
187 content_references: None,
188 code_citations: None,
189 response_markdown_info: None,
190 source_session: None,
191 });
192 }
193 }
194 "system" => {
195 }
197 _ => {}
198 }
199 }
200
201 ChatSession {
202 version: 3,
203 session_id: Some(session_id),
204 creation_date: now,
205 last_message_date: now,
206 is_imported: true,
207 initial_location: "api".to_string(),
208 custom_title: Some(format!("{} Chat", provider_name)),
209 requester_username: Some("user".to_string()),
210 requester_avatar_icon_uri: None,
211 responder_username: Some(format!("{}/{}", provider_name, model)),
212 responder_avatar_icon_uri: None,
213 requests,
214 }
215 }
216}
217
218impl ChatProvider for OpenAICompatProvider {
219 fn provider_type(&self) -> ProviderType {
220 self.provider_type
221 }
222
223 fn name(&self) -> &str {
224 &self.name
225 }
226
227 fn is_available(&self) -> bool {
228 self.available
229 }
230
231 fn sessions_path(&self) -> Option<PathBuf> {
232 self.data_path.clone()
233 }
234
235 fn list_sessions(&self) -> Result<Vec<ChatSession>> {
236 Ok(Vec::new())
239 }
240
241 fn import_session(&self, _session_id: &str) -> Result<ChatSession> {
242 anyhow::bail!("{} does not persist chat sessions", self.name)
243 }
244
245 fn export_session(&self, _session: &ChatSession) -> Result<()> {
246 anyhow::bail!("Export to {} not yet implemented", self.name)
248 }
249}
250
251pub fn discover_openai_compatible_providers() -> Vec<OpenAICompatProvider> {
253 let mut providers = Vec::new();
254
255 if let Some(provider) = discover_vllm() {
257 providers.push(provider);
258 }
259
260 if let Some(provider) = discover_lm_studio() {
262 providers.push(provider);
263 }
264
265 if let Some(provider) = discover_localai() {
267 providers.push(provider);
268 }
269
270 if let Some(provider) = discover_text_gen_webui() {
272 providers.push(provider);
273 }
274
275 if let Some(provider) = discover_jan() {
277 providers.push(provider);
278 }
279
280 if let Some(provider) = discover_gpt4all() {
282 providers.push(provider);
283 }
284
285 if let Some(provider) = discover_foundry() {
287 providers.push(provider);
288 }
289
290 providers
291}
292
293fn discover_vllm() -> Option<OpenAICompatProvider> {
294 let endpoint =
295 std::env::var("VLLM_ENDPOINT").unwrap_or_else(|_| "http://localhost:8000/v1".to_string());
296
297 Some(OpenAICompatProvider::new(
298 ProviderType::Vllm,
299 "vLLM",
300 endpoint,
301 ))
302}
303
304fn discover_lm_studio() -> Option<OpenAICompatProvider> {
305 let endpoint = std::env::var("LM_STUDIO_ENDPOINT")
306 .unwrap_or_else(|_| "http://localhost:1234/v1".to_string());
307
308 let data_path = find_lm_studio_data();
310
311 let mut provider = OpenAICompatProvider::new(ProviderType::LmStudio, "LM Studio", endpoint);
312
313 if let Some(path) = data_path {
314 provider = provider.with_data_path(path);
315 }
316
317 Some(provider)
318}
319
320fn discover_localai() -> Option<OpenAICompatProvider> {
321 let endpoint = std::env::var("LOCALAI_ENDPOINT")
322 .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());
323
324 Some(OpenAICompatProvider::new(
325 ProviderType::LocalAI,
326 "LocalAI",
327 endpoint,
328 ))
329}
330
331fn discover_text_gen_webui() -> Option<OpenAICompatProvider> {
332 let endpoint = std::env::var("TEXT_GEN_WEBUI_ENDPOINT")
333 .unwrap_or_else(|_| "http://localhost:5000/v1".to_string());
334
335 Some(OpenAICompatProvider::new(
336 ProviderType::TextGenWebUI,
337 "Text Generation WebUI",
338 endpoint,
339 ))
340}
341
342fn discover_jan() -> Option<OpenAICompatProvider> {
343 let endpoint =
344 std::env::var("JAN_ENDPOINT").unwrap_or_else(|_| "http://localhost:1337/v1".to_string());
345
346 let data_path = find_jan_data();
348
349 let mut provider = OpenAICompatProvider::new(ProviderType::Jan, "Jan.ai", endpoint);
350
351 if let Some(path) = data_path {
352 provider = provider.with_data_path(path);
353 }
354
355 Some(provider)
356}
357
358fn discover_gpt4all() -> Option<OpenAICompatProvider> {
359 let endpoint = std::env::var("GPT4ALL_ENDPOINT")
360 .unwrap_or_else(|_| "http://localhost:4891/v1".to_string());
361
362 let data_path = find_gpt4all_data();
364
365 let mut provider = OpenAICompatProvider::new(ProviderType::Gpt4All, "GPT4All", endpoint);
366
367 if let Some(path) = data_path {
368 provider = provider.with_data_path(path);
369 }
370
371 Some(provider)
372}
373
374fn discover_foundry() -> Option<OpenAICompatProvider> {
375 let endpoint = std::env::var("FOUNDRY_LOCAL_ENDPOINT")
377 .or_else(|_| std::env::var("AI_FOUNDRY_ENDPOINT"))
378 .unwrap_or_else(|_| "http://localhost:5272/v1".to_string());
379
380 Some(OpenAICompatProvider::new(
381 ProviderType::Foundry,
382 "Azure AI Foundry",
383 endpoint,
384 ))
385}
386
387fn find_lm_studio_data() -> Option<PathBuf> {
390 #[cfg(target_os = "windows")]
391 {
392 let home = dirs::home_dir()?;
393 let path = home.join(".cache").join("lm-studio");
394 if path.exists() {
395 return Some(path);
396 }
397 }
398
399 #[cfg(target_os = "macos")]
400 {
401 let home = dirs::home_dir()?;
402 let path = home.join(".cache").join("lm-studio");
403 if path.exists() {
404 return Some(path);
405 }
406 }
407
408 #[cfg(target_os = "linux")]
409 {
410 if let Some(cache_dir) = dirs::cache_dir() {
411 let path = cache_dir.join("lm-studio");
412 if path.exists() {
413 return Some(path);
414 }
415 }
416 }
417
418 None
419}
420
421fn find_jan_data() -> Option<PathBuf> {
422 #[cfg(target_os = "windows")]
423 {
424 let home = dirs::home_dir()?;
425 let path = home.join("jan");
426 if path.exists() {
427 return Some(path);
428 }
429 }
430
431 #[cfg(target_os = "macos")]
432 {
433 let home = dirs::home_dir()?;
434 let path = home.join("jan");
435 if path.exists() {
436 return Some(path);
437 }
438 }
439
440 #[cfg(target_os = "linux")]
441 {
442 let home = dirs::home_dir()?;
443 let path = home.join("jan");
444 if path.exists() {
445 return Some(path);
446 }
447 }
448
449 None
450}
451
452fn find_gpt4all_data() -> Option<PathBuf> {
453 #[cfg(target_os = "windows")]
454 {
455 let local_app_data = dirs::data_local_dir()?;
456 let path = local_app_data.join("nomic.ai").join("GPT4All");
457 if path.exists() {
458 return Some(path);
459 }
460 }
461
462 #[cfg(target_os = "macos")]
463 {
464 let home = dirs::home_dir()?;
465 let path = home
466 .join("Library")
467 .join("Application Support")
468 .join("nomic.ai")
469 .join("GPT4All");
470 if path.exists() {
471 return Some(path);
472 }
473 }
474
475 #[cfg(target_os = "linux")]
476 {
477 if let Some(data_dir) = dirs::data_dir() {
478 let path = data_dir.join("nomic.ai").join("GPT4All");
479 if path.exists() {
480 return Some(path);
481 }
482 }
483 }
484
485 None
486}
487
488fn extract_response_text(response: &serde_json::Value) -> Option<String> {
490 if let Some(text) = response.get("text").and_then(|v| v.as_str()) {
492 return Some(text.to_string());
493 }
494
495 if let Some(value) = response.get("value").and_then(|v| v.as_array()) {
497 let parts: Vec<String> = value
498 .iter()
499 .filter_map(|v| v.get("value").and_then(|v| v.as_str()))
500 .map(String::from)
501 .collect();
502 if !parts.is_empty() {
503 return Some(parts.join("\n"));
504 }
505 }
506
507 if let Some(content) = response.get("content").and_then(|v| v.as_str()) {
509 return Some(content.to_string());
510 }
511
512 None
513}