1#![allow(dead_code)]
17
18use super::{ChatProvider, ProviderType};
19use crate::models::{ChatMessage, ChatRequest, ChatSession};
20use anyhow::Result;
21use serde::{Deserialize, Serialize};
22use std::path::PathBuf;
23
24pub struct OpenAICompatProvider {
26 provider_type: ProviderType,
28 name: String,
30 endpoint: String,
32 api_key: Option<String>,
34 model: Option<String>,
36 available: bool,
38 data_path: Option<PathBuf>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct OpenAIChatMessage {
45 pub role: String,
46 pub content: String,
47}
48
49#[derive(Debug, Serialize)]
51pub struct OpenAIChatRequest {
52 pub model: String,
53 pub messages: Vec<OpenAIChatMessage>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub temperature: Option<f32>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub max_tokens: Option<u32>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub stream: Option<bool>,
60}
61
62#[derive(Debug, Deserialize)]
64pub struct OpenAIChatResponse {
65 pub id: String,
66 pub choices: Vec<OpenAIChatChoice>,
67 #[allow(dead_code)]
68 pub model: String,
69}
70
71#[derive(Debug, Deserialize)]
73pub struct OpenAIChatChoice {
74 pub message: OpenAIChatMessage,
75 #[allow(dead_code)]
76 pub finish_reason: Option<String>,
77}
78
79impl OpenAICompatProvider {
80 pub fn new(
82 provider_type: ProviderType,
83 name: impl Into<String>,
84 endpoint: impl Into<String>,
85 ) -> Self {
86 let endpoint = endpoint.into();
87 Self {
88 provider_type,
89 name: name.into(),
90 endpoint: endpoint.clone(),
91 api_key: None,
92 model: None,
93 available: Self::check_availability(&endpoint),
94 data_path: None,
95 }
96 }
97
98 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
100 self.api_key = Some(api_key.into());
101 self
102 }
103
104 pub fn with_model(mut self, model: impl Into<String>) -> Self {
106 self.model = Some(model.into());
107 self
108 }
109
110 pub fn with_data_path(mut self, path: PathBuf) -> Self {
112 self.data_path = Some(path);
113 self
114 }
115
116 fn check_availability(endpoint: &str) -> bool {
118 !endpoint.is_empty()
120 }
121
122 pub fn session_to_messages(session: &ChatSession) -> Vec<OpenAIChatMessage> {
124 let mut messages = Vec::new();
125
126 for request in &session.requests {
127 if let Some(msg) = &request.message {
129 if let Some(text) = &msg.text {
130 messages.push(OpenAIChatMessage {
131 role: "user".to_string(),
132 content: text.clone(),
133 });
134 }
135 }
136
137 if let Some(response) = &request.response {
139 if let Some(text) = extract_response_text(response) {
140 messages.push(OpenAIChatMessage {
141 role: "assistant".to_string(),
142 content: text,
143 });
144 }
145 }
146 }
147
148 messages
149 }
150
151 pub fn messages_to_session(
153 messages: Vec<OpenAIChatMessage>,
154 model: &str,
155 provider_name: &str,
156 ) -> ChatSession {
157 let now = chrono::Utc::now().timestamp_millis();
158 let session_id = uuid::Uuid::new_v4().to_string();
159
160 let mut requests = Vec::new();
161 let mut user_msg: Option<String> = None;
162
163 for msg in messages {
164 match msg.role.as_str() {
165 "user" => {
166 user_msg = Some(msg.content);
167 }
168 "assistant" => {
169 if let Some(user_text) = user_msg.take() {
170 requests.push(ChatRequest {
171 timestamp: Some(now),
172 message: Some(ChatMessage {
173 text: Some(user_text),
174 parts: None,
175 }),
176 response: Some(serde_json::json!({
177 "value": [{"value": msg.content}]
178 })),
179 variable_data: None,
180 request_id: Some(uuid::Uuid::new_v4().to_string()),
181 response_id: Some(uuid::Uuid::new_v4().to_string()),
182 model_id: Some(model.to_string()),
183 agent: None,
184 result: None,
185 followups: None,
186 is_canceled: Some(false),
187 content_references: None,
188 code_citations: None,
189 response_markdown_info: None,
190 source_session: None,
191 model_state: None,
192 time_spent_waiting: None,
193 });
194 }
195 }
196 "system" => {
197 }
199 _ => {}
200 }
201 }
202
203 ChatSession {
204 version: 3,
205 session_id: Some(session_id),
206 creation_date: now,
207 last_message_date: now,
208 is_imported: true,
209 initial_location: "api".to_string(),
210 custom_title: Some(format!("{} Chat", provider_name)),
211 requester_username: Some("user".to_string()),
212 requester_avatar_icon_uri: None,
213 responder_username: Some(format!("{}/{}", provider_name, model)),
214 responder_avatar_icon_uri: None,
215 requests,
216 }
217 }
218}
219
220impl ChatProvider for OpenAICompatProvider {
221 fn provider_type(&self) -> ProviderType {
222 self.provider_type
223 }
224
225 fn name(&self) -> &str {
226 &self.name
227 }
228
229 fn is_available(&self) -> bool {
230 self.available
231 }
232
233 fn sessions_path(&self) -> Option<PathBuf> {
234 self.data_path.clone()
235 }
236
237 fn list_sessions(&self) -> Result<Vec<ChatSession>> {
238 Ok(Vec::new())
241 }
242
243 fn import_session(&self, _session_id: &str) -> Result<ChatSession> {
244 anyhow::bail!("{} does not persist chat sessions", self.name)
245 }
246
247 fn export_session(&self, _session: &ChatSession) -> Result<()> {
248 anyhow::bail!("Export to {} not yet implemented", self.name)
250 }
251}
252
253pub fn discover_openai_compatible_providers() -> Vec<OpenAICompatProvider> {
255 let mut providers = Vec::new();
256
257 if let Some(provider) = discover_vllm() {
259 providers.push(provider);
260 }
261
262 if let Some(provider) = discover_lm_studio() {
264 providers.push(provider);
265 }
266
267 if let Some(provider) = discover_localai() {
269 providers.push(provider);
270 }
271
272 if let Some(provider) = discover_text_gen_webui() {
274 providers.push(provider);
275 }
276
277 if let Some(provider) = discover_jan() {
279 providers.push(provider);
280 }
281
282 if let Some(provider) = discover_gpt4all() {
284 providers.push(provider);
285 }
286
287 if let Some(provider) = discover_foundry() {
289 providers.push(provider);
290 }
291
292 providers
293}
294
295fn discover_vllm() -> Option<OpenAICompatProvider> {
296 let endpoint =
297 std::env::var("VLLM_ENDPOINT").unwrap_or_else(|_| "http://localhost:8000/v1".to_string());
298
299 Some(OpenAICompatProvider::new(
300 ProviderType::Vllm,
301 "vLLM",
302 endpoint,
303 ))
304}
305
306fn discover_lm_studio() -> Option<OpenAICompatProvider> {
307 let endpoint = std::env::var("LM_STUDIO_ENDPOINT")
308 .unwrap_or_else(|_| "http://localhost:1234/v1".to_string());
309
310 let data_path = find_lm_studio_data();
312
313 let mut provider = OpenAICompatProvider::new(ProviderType::LmStudio, "LM Studio", endpoint);
314
315 if let Some(path) = data_path {
316 provider = provider.with_data_path(path);
317 }
318
319 Some(provider)
320}
321
322fn discover_localai() -> Option<OpenAICompatProvider> {
323 let endpoint = std::env::var("LOCALAI_ENDPOINT")
324 .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());
325
326 Some(OpenAICompatProvider::new(
327 ProviderType::LocalAI,
328 "LocalAI",
329 endpoint,
330 ))
331}
332
333fn discover_text_gen_webui() -> Option<OpenAICompatProvider> {
334 let endpoint = std::env::var("TEXT_GEN_WEBUI_ENDPOINT")
335 .unwrap_or_else(|_| "http://localhost:5000/v1".to_string());
336
337 Some(OpenAICompatProvider::new(
338 ProviderType::TextGenWebUI,
339 "Text Generation WebUI",
340 endpoint,
341 ))
342}
343
344fn discover_jan() -> Option<OpenAICompatProvider> {
345 let endpoint =
346 std::env::var("JAN_ENDPOINT").unwrap_or_else(|_| "http://localhost:1337/v1".to_string());
347
348 let data_path = find_jan_data();
350
351 let mut provider = OpenAICompatProvider::new(ProviderType::Jan, "Jan.ai", endpoint);
352
353 if let Some(path) = data_path {
354 provider = provider.with_data_path(path);
355 }
356
357 Some(provider)
358}
359
360fn discover_gpt4all() -> Option<OpenAICompatProvider> {
361 let endpoint = std::env::var("GPT4ALL_ENDPOINT")
362 .unwrap_or_else(|_| "http://localhost:4891/v1".to_string());
363
364 let data_path = find_gpt4all_data();
366
367 let mut provider = OpenAICompatProvider::new(ProviderType::Gpt4All, "GPT4All", endpoint);
368
369 if let Some(path) = data_path {
370 provider = provider.with_data_path(path);
371 }
372
373 Some(provider)
374}
375
376fn discover_foundry() -> Option<OpenAICompatProvider> {
377 let endpoint = std::env::var("FOUNDRY_LOCAL_ENDPOINT")
379 .or_else(|_| std::env::var("AI_FOUNDRY_ENDPOINT"))
380 .unwrap_or_else(|_| "http://localhost:5272/v1".to_string());
381
382 Some(OpenAICompatProvider::new(
383 ProviderType::Foundry,
384 "Azure AI Foundry",
385 endpoint,
386 ))
387}
388
389fn find_lm_studio_data() -> Option<PathBuf> {
392 #[cfg(target_os = "windows")]
393 {
394 let home = dirs::home_dir()?;
395 let path = home.join(".cache").join("lm-studio");
396 if path.exists() {
397 return Some(path);
398 }
399 }
400
401 #[cfg(target_os = "macos")]
402 {
403 let home = dirs::home_dir()?;
404 let path = home.join(".cache").join("lm-studio");
405 if path.exists() {
406 return Some(path);
407 }
408 }
409
410 #[cfg(target_os = "linux")]
411 {
412 if let Some(cache_dir) = dirs::cache_dir() {
413 let path = cache_dir.join("lm-studio");
414 if path.exists() {
415 return Some(path);
416 }
417 }
418 }
419
420 None
421}
422
423fn find_jan_data() -> Option<PathBuf> {
424 #[cfg(target_os = "windows")]
425 {
426 let home = dirs::home_dir()?;
427 let path = home.join("jan");
428 if path.exists() {
429 return Some(path);
430 }
431 }
432
433 #[cfg(target_os = "macos")]
434 {
435 let home = dirs::home_dir()?;
436 let path = home.join("jan");
437 if path.exists() {
438 return Some(path);
439 }
440 }
441
442 #[cfg(target_os = "linux")]
443 {
444 let home = dirs::home_dir()?;
445 let path = home.join("jan");
446 if path.exists() {
447 return Some(path);
448 }
449 }
450
451 None
452}
453
454fn find_gpt4all_data() -> Option<PathBuf> {
455 #[cfg(target_os = "windows")]
456 {
457 let local_app_data = dirs::data_local_dir()?;
458 let path = local_app_data.join("nomic.ai").join("GPT4All");
459 if path.exists() {
460 return Some(path);
461 }
462 }
463
464 #[cfg(target_os = "macos")]
465 {
466 let home = dirs::home_dir()?;
467 let path = home
468 .join("Library")
469 .join("Application Support")
470 .join("nomic.ai")
471 .join("GPT4All");
472 if path.exists() {
473 return Some(path);
474 }
475 }
476
477 #[cfg(target_os = "linux")]
478 {
479 if let Some(data_dir) = dirs::data_dir() {
480 let path = data_dir.join("nomic.ai").join("GPT4All");
481 if path.exists() {
482 return Some(path);
483 }
484 }
485 }
486
487 None
488}
489
490fn extract_response_text(response: &serde_json::Value) -> Option<String> {
492 if let Some(text) = response.get("text").and_then(|v| v.as_str()) {
494 return Some(text.to_string());
495 }
496
497 if let Some(value) = response.get("value").and_then(|v| v.as_array()) {
499 let parts: Vec<String> = value
500 .iter()
501 .filter_map(|v| v.get("value").and_then(|v| v.as_str()))
502 .map(String::from)
503 .collect();
504 if !parts.is_empty() {
505 return Some(parts.join("\n"));
506 }
507 }
508
509 if let Some(content) = response.get("content").and_then(|v| v.as_str()) {
511 return Some(content.to_string());
512 }
513
514 None
515}