chasm/providers/openai_compat.rs

// Copyright (c) 2024-2026 Nervosys LLC
// SPDX-License-Identifier: AGPL-3.0-only
//! OpenAI-compatible provider support
//!
//! Supports servers that implement the OpenAI Chat Completions API:
//! - vLLM
//! - LM Studio
//! - LocalAI
//! - Text Generation WebUI
//! - Jan.ai
//! - GPT4All
//! - Llamafile
//! - Azure AI Foundry (Foundry Local)
//! - Any custom OpenAI-compatible endpoint

#![allow(dead_code)]

use super::{ChatProvider, ProviderType};
use crate::models::{ChatMessage, ChatRequest, ChatSession};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

/// OpenAI-compatible API provider
pub struct OpenAICompatProvider {
    /// Provider type
    provider_type: ProviderType,
    /// Display name
    name: String,
    /// API endpoint URL
    endpoint: String,
    /// API key (if required)
    api_key: Option<String>,
    /// Default model
    model: Option<String>,
    /// Whether the endpoint is available
    available: bool,
    /// Local data path (if any)
    data_path: Option<PathBuf>,
}

/// OpenAI chat message format
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIChatMessage {
    pub role: String,
    pub content: String,
}

/// OpenAI chat completion request
#[derive(Debug, Serialize)]
pub struct OpenAIChatRequest {
    pub model: String,
    pub messages: Vec<OpenAIChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
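
// Example of the JSON wire format `OpenAIChatRequest` serializes to; the
// model name and values below are illustrative placeholders, not output
// from any real deployment:
//
// {
//   "model": "my-model",
//   "messages": [{"role": "user", "content": "Hello"}],
//   "temperature": 0.7
// }
//
// (`max_tokens` and `stream` are omitted entirely when `None`, per the
// `skip_serializing_if` attributes above.)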

/// OpenAI chat completion response
#[derive(Debug, Deserialize)]
pub struct OpenAIChatResponse {
    pub id: String,
    pub choices: Vec<OpenAIChatChoice>,
    #[allow(dead_code)]
    pub model: String,
}

/// OpenAI chat completion choice
#[derive(Debug, Deserialize)]
pub struct OpenAIChatChoice {
    pub message: OpenAIChatMessage,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
}

impl OpenAICompatProvider {
    /// Create a new OpenAI-compatible provider
    pub fn new(
        provider_type: ProviderType,
        name: impl Into<String>,
        endpoint: impl Into<String>,
    ) -> Self {
        let endpoint = endpoint.into();
        Self {
            provider_type,
            name: name.into(),
            endpoint: endpoint.clone(),
            api_key: None,
            model: None,
            available: Self::check_availability(&endpoint),
            data_path: None,
        }
    }

    /// Set API key
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Set default model
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Set local data path
    pub fn with_data_path(mut self, path: PathBuf) -> Self {
        self.data_path = Some(path);
        self
    }

    /// Check if the endpoint is available
    fn check_availability(endpoint: &str) -> bool {
        // Basic check: a production build would probe the endpoint with an
        // HTTP client (see the sketch below).
        !endpoint.is_empty()
    }
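
    // A minimal sketch of what a real availability probe could look like,
    // assuming the `reqwest` crate (with its `blocking` feature) were added
    // as a dependency; the `http-probe` feature gate is hypothetical and
    // keeps the sketch out of default builds. `GET {endpoint}/models` is the
    // standard OpenAI-compatible model-listing route.
    #[cfg(feature = "http-probe")]
    fn check_availability_http(endpoint: &str) -> bool {
        let url = format!("{}/models", endpoint.trim_end_matches('/'));
        reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_secs(2))
            .build()
            .and_then(|client| client.get(&url).send())
            .map(|resp| resp.status().is_success())
            .unwrap_or(false)
    }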

    /// Convert CSM session to OpenAI message format
    pub fn session_to_messages(session: &ChatSession) -> Vec<OpenAIChatMessage> {
        let mut messages = Vec::new();

        for request in &session.requests {
            // Add user message
            if let Some(msg) = &request.message {
                if let Some(text) = &msg.text {
                    messages.push(OpenAIChatMessage {
                        role: "user".to_string(),
                        content: text.clone(),
                    });
                }
            }

            // Add assistant response
            if let Some(response) = &request.response {
                if let Some(text) = extract_response_text(response) {
                    messages.push(OpenAIChatMessage {
                        role: "assistant".to_string(),
                        content: text,
                    });
                }
            }
        }

        messages
    }

    /// Convert OpenAI messages to CSM session
    pub fn messages_to_session(
        messages: Vec<OpenAIChatMessage>,
        model: &str,
        provider_name: &str,
    ) -> ChatSession {
        let now = chrono::Utc::now().timestamp_millis();
        let session_id = uuid::Uuid::new_v4().to_string();

        let mut requests = Vec::new();
        let mut user_msg: Option<String> = None;

        for msg in messages {
            match msg.role.as_str() {
                "user" => {
                    user_msg = Some(msg.content);
                }
                "assistant" => {
                    if let Some(user_text) = user_msg.take() {
                        requests.push(ChatRequest {
                            timestamp: Some(now),
                            message: Some(ChatMessage {
                                text: Some(user_text),
                                parts: None,
                            }),
                            response: Some(serde_json::json!({
                                "value": [{"value": msg.content}]
                            })),
                            variable_data: None,
                            request_id: Some(uuid::Uuid::new_v4().to_string()),
                            response_id: Some(uuid::Uuid::new_v4().to_string()),
                            model_id: Some(model.to_string()),
                            agent: None,
                            result: None,
                            followups: None,
                            is_canceled: Some(false),
                            content_references: None,
                            code_citations: None,
                            response_markdown_info: None,
                            source_session: None,
                            model_state: None,
                            time_spent_waiting: None,
                        });
                    }
                }
                "system" => {
                    // System messages could be stored as metadata
                }
                _ => {}
            }
        }

        ChatSession {
            version: 3,
            session_id: Some(session_id),
            creation_date: now,
            last_message_date: now,
            is_imported: true,
            initial_location: "api".to_string(),
            custom_title: Some(format!("{} Chat", provider_name)),
            requester_username: Some("user".to_string()),
            requester_avatar_icon_uri: None,
            responder_username: Some(format!("{}/{}", provider_name, model)),
            responder_avatar_icon_uri: None,
            requests,
        }
    }
}
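
// Usage sketch: constructing a provider by hand rather than via discovery.
// The endpoint, key, and model name below are illustrative placeholders.
//
// let provider = OpenAICompatProvider::new(
//     ProviderType::Vllm,
//     "vLLM",
//     "http://localhost:8000/v1",
// )
// .with_api_key("sk-local")
// .with_model("my-model");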

impl ChatProvider for OpenAICompatProvider {
    fn provider_type(&self) -> ProviderType {
        self.provider_type
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn is_available(&self) -> bool {
        self.available
    }

    fn sessions_path(&self) -> Option<PathBuf> {
        self.data_path.clone()
    }

    fn list_sessions(&self) -> Result<Vec<ChatSession>> {
        // OpenAI-compatible APIs don't persist sessions;
        // this would need a local history storage layer.
        Ok(Vec::new())
    }

    fn import_session(&self, _session_id: &str) -> Result<ChatSession> {
        anyhow::bail!("{} does not persist chat sessions", self.name)
    }

    fn export_session(&self, _session: &ChatSession) -> Result<()> {
        // Could implement by sending messages to recreate context
        anyhow::bail!("Export to {} not yet implemented", self.name)
    }
}

/// Discover available OpenAI-compatible providers
pub fn discover_openai_compatible_providers() -> Vec<OpenAICompatProvider> {
    let mut providers = Vec::new();

    // vLLM (default port 8000)
    if let Some(provider) = discover_vllm() {
        providers.push(provider);
    }

    // LM Studio (default port 1234)
    if let Some(provider) = discover_lm_studio() {
        providers.push(provider);
    }

    // LocalAI (default port 8080)
    if let Some(provider) = discover_localai() {
        providers.push(provider);
    }

    // Text Generation WebUI (default port 5000)
    if let Some(provider) = discover_text_gen_webui() {
        providers.push(provider);
    }

    // Jan.ai (default port 1337)
    if let Some(provider) = discover_jan() {
        providers.push(provider);
    }

    // GPT4All (default port 4891)
    if let Some(provider) = discover_gpt4all() {
        providers.push(provider);
    }

    // Azure AI Foundry / Foundry Local (default port 5272)
    if let Some(provider) = discover_foundry() {
        providers.push(provider);
    }

    providers
}
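
// Usage sketch: enumerate whatever local servers answer on their default
// ports (or environment-variable overrides) and report availability.
//
// for provider in discover_openai_compatible_providers() {
//     println!("{}: available = {}", provider.name(), provider.is_available());
// }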

fn discover_vllm() -> Option<OpenAICompatProvider> {
    let endpoint =
        std::env::var("VLLM_ENDPOINT").unwrap_or_else(|_| "http://localhost:8000/v1".to_string());

    Some(OpenAICompatProvider::new(
        ProviderType::Vllm,
        "vLLM",
        endpoint,
    ))
}

fn discover_lm_studio() -> Option<OpenAICompatProvider> {
    let endpoint = std::env::var("LM_STUDIO_ENDPOINT")
        .unwrap_or_else(|_| "http://localhost:1234/v1".to_string());

    // Check for LM Studio data directory
    let data_path = find_lm_studio_data();

    let mut provider = OpenAICompatProvider::new(ProviderType::LmStudio, "LM Studio", endpoint);

    if let Some(path) = data_path {
        provider = provider.with_data_path(path);
    }

    Some(provider)
}

fn discover_localai() -> Option<OpenAICompatProvider> {
    let endpoint = std::env::var("LOCALAI_ENDPOINT")
        .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());

    Some(OpenAICompatProvider::new(
        ProviderType::LocalAI,
        "LocalAI",
        endpoint,
    ))
}

fn discover_text_gen_webui() -> Option<OpenAICompatProvider> {
    let endpoint = std::env::var("TEXT_GEN_WEBUI_ENDPOINT")
        .unwrap_or_else(|_| "http://localhost:5000/v1".to_string());

    Some(OpenAICompatProvider::new(
        ProviderType::TextGenWebUI,
        "Text Generation WebUI",
        endpoint,
    ))
}

fn discover_jan() -> Option<OpenAICompatProvider> {
    let endpoint =
        std::env::var("JAN_ENDPOINT").unwrap_or_else(|_| "http://localhost:1337/v1".to_string());

    // Check for Jan data directory
    let data_path = find_jan_data();

    let mut provider = OpenAICompatProvider::new(ProviderType::Jan, "Jan.ai", endpoint);

    if let Some(path) = data_path {
        provider = provider.with_data_path(path);
    }

    Some(provider)
}

fn discover_gpt4all() -> Option<OpenAICompatProvider> {
    let endpoint = std::env::var("GPT4ALL_ENDPOINT")
        .unwrap_or_else(|_| "http://localhost:4891/v1".to_string());

    // Check for GPT4All data directory
    let data_path = find_gpt4all_data();

    let mut provider = OpenAICompatProvider::new(ProviderType::Gpt4All, "GPT4All", endpoint);

    if let Some(path) = data_path {
        provider = provider.with_data_path(path);
    }

    Some(provider)
}

fn discover_foundry() -> Option<OpenAICompatProvider> {
    // Azure AI Foundry Local / Foundry Local
    let endpoint = std::env::var("FOUNDRY_LOCAL_ENDPOINT")
        .or_else(|_| std::env::var("AI_FOUNDRY_ENDPOINT"))
        .unwrap_or_else(|_| "http://localhost:5272/v1".to_string());

    Some(OpenAICompatProvider::new(
        ProviderType::Foundry,
        "Azure AI Foundry",
        endpoint,
    ))
}

// Helper functions to find application data directories

fn find_lm_studio_data() -> Option<PathBuf> {
    // Windows and macOS both use `~/.cache/lm-studio`; Linux respects the
    // XDG cache directory.
    #[cfg(any(target_os = "windows", target_os = "macos"))]
    {
        let home = dirs::home_dir()?;
        let path = home.join(".cache").join("lm-studio");
        if path.exists() {
            return Some(path);
        }
    }

    #[cfg(target_os = "linux")]
    {
        if let Some(cache_dir) = dirs::cache_dir() {
            let path = cache_dir.join("lm-studio");
            if path.exists() {
                return Some(path);
            }
        }
    }

    None
}

fn find_jan_data() -> Option<PathBuf> {
    // Jan uses `~/jan` on all three desktop platforms.
    #[cfg(any(target_os = "windows", target_os = "macos", target_os = "linux"))]
    {
        let home = dirs::home_dir()?;
        let path = home.join("jan");
        if path.exists() {
            return Some(path);
        }
    }

    None
}

fn find_gpt4all_data() -> Option<PathBuf> {
    #[cfg(target_os = "windows")]
    {
        let local_app_data = dirs::data_local_dir()?;
        let path = local_app_data.join("nomic.ai").join("GPT4All");
        if path.exists() {
            return Some(path);
        }
    }

    #[cfg(target_os = "macos")]
    {
        let home = dirs::home_dir()?;
        let path = home
            .join("Library")
            .join("Application Support")
            .join("nomic.ai")
            .join("GPT4All");
        if path.exists() {
            return Some(path);
        }
    }

    #[cfg(target_os = "linux")]
    {
        if let Some(data_dir) = dirs::data_dir() {
            let path = data_dir.join("nomic.ai").join("GPT4All");
            if path.exists() {
                return Some(path);
            }
        }
    }

    None
}

/// Extract text from various response formats
fn extract_response_text(response: &serde_json::Value) -> Option<String> {
    // Try direct text field
    if let Some(text) = response.get("text").and_then(|v| v.as_str()) {
        return Some(text.to_string());
    }

    // Try value array format (VS Code Copilot format)
    if let Some(value) = response.get("value").and_then(|v| v.as_array()) {
        let parts: Vec<String> = value
            .iter()
            .filter_map(|v| v.get("value").and_then(|v| v.as_str()))
            .map(String::from)
            .collect();
        if !parts.is_empty() {
            return Some(parts.join("\n"));
        }
    }

    // Try content field (OpenAI format)
    if let Some(content) = response.get("content").and_then(|v| v.as_str()) {
        return Some(content.to_string());
    }

    None
}
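
// A minimal test sketch for the conversion helpers above; the JSON payloads
// and message contents are synthetic fixtures, not captured API output.
#[cfg(test)]
mod tests {
    use super::*;

    // Exercises the three response shapes `extract_response_text` understands.
    #[test]
    fn extracts_text_from_known_formats() {
        let direct = serde_json::json!({"text": "hi"});
        assert_eq!(extract_response_text(&direct).as_deref(), Some("hi"));

        let copilot = serde_json::json!({"value": [{"value": "a"}, {"value": "b"}]});
        assert_eq!(extract_response_text(&copilot).as_deref(), Some("a\nb"));

        let openai = serde_json::json!({"content": "hello"});
        assert_eq!(extract_response_text(&openai).as_deref(), Some("hello"));

        assert!(extract_response_text(&serde_json::json!({})).is_none());
    }

    // Round-trips a user/assistant pair through `messages_to_session` and
    // back out through `session_to_messages`.
    #[test]
    fn message_round_trip() {
        let messages = vec![
            OpenAIChatMessage {
                role: "user".into(),
                content: "ping".into(),
            },
            OpenAIChatMessage {
                role: "assistant".into(),
                content: "pong".into(),
            },
        ];
        let session = OpenAICompatProvider::messages_to_session(messages, "test-model", "Test");
        let back = OpenAICompatProvider::session_to_messages(&session);
        assert_eq!(back.len(), 2);
        assert_eq!(back[0].role, "user");
        assert_eq!(back[0].content, "ping");
        assert_eq!(back[1].role, "assistant");
        assert_eq!(back[1].content, "pong");
    }
}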