bytedocs_rs/ai/client.rs

use crate::ai::types::{register_client_factory, AIConfig, Client, ChatRequest, ChatResponse};
use anyhow::Result;
use serde_json::json;

pub struct OpenAIClient {
    config: AIConfig,
}

pub struct GeminiClient {
    config: AIConfig,
}

pub struct OpenRouterClient {
    config: AIConfig,
}

impl OpenAIClient {
    pub fn new(config: &AIConfig) -> Result<Box<dyn Client>> {
        Ok(Box::new(Self {
            config: config.clone(),
        }))
    }

    // Condense the API context into an ultra-compact format, mirroring the
    // Laravel implementation
    fn optimize_api_context(&self, context: &str, user_message: &str) -> String {
        // Parse the context to extract API info and endpoints
        let (api_info, endpoints) = self.parse_context(context);

        // Find the endpoints most relevant to the user's question
        let relevant_endpoints = self.find_relevant_endpoints(&endpoints, user_message);

        // Header line: "API: <title> | <base_url>"
        let mut result = format!("API: {} | {}", api_info.0, api_info.1);

        // Append compact endpoints (at most 5, to save tokens); the leading
        // newline separates the header from the first endpoint
        let max_endpoints = 5;
        for (i, endpoint) in relevant_endpoints.iter().take(max_endpoints).enumerate() {
            if i == 0 {
                result.push('\n');
            }
            result.push_str(&self.format_compact_endpoint(endpoint));
        }

        // Add response-format and grounding rules
        result.push_str("\nRESP: {success:bool,message:str,data:obj}");
        result.push_str("\nRULES: Use only listed endpoints. No invention.");
        result
    }
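
    // Illustrative output, assuming the hardcoded endpoint table below and a
    // question about users (values are examples, not fixed output):
    //
    //   API: My API | localhost:8000
    //   GET /api/users ?page,limit,status,search
    //   POST /api/users {name*,email*,age}
    //
    //   RESP: {success:bool,message:str,data:obj}
    //   RULES: Use only listed endpoints. No invention.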

    // Parse the context to extract API info (title, base URL) and endpoints
    fn parse_context(&self, context: &str) -> ((String, String), Vec<CompactEndpoint>) {
        let mut api_title = "API".to_string();
        let mut base_url = "localhost:8000".to_string();
        let mut endpoints = Vec::new();

        // Extract the API title (the rest of the "API Title: ..." line)
        const TITLE_MARKER: &str = "API Title: ";
        if let Some(title_start) = context.find(TITLE_MARKER) {
            let after_marker = title_start + TITLE_MARKER.len();
            if let Some(title_end) = context[after_marker..].find('\n') {
                api_title = context[after_marker..after_marker + title_end].trim().to_string();
            }
        }

        // Extract the first base URL from the "Base URLs: " line
        if let Some(url_start) = context.find("Base URLs: ") {
            if let Some(url_end) = context[url_start..].find(',') {
                let url_part = &context[url_start..url_start + url_end];
                if let Some(first_url) = url_part.split("Production: ").nth(1) {
                    base_url = first_url.trim().to_string();
                }
            }
        }

        // Parse endpoints from the embedded OpenAPI paths object
        if let Some(paths_start) = context.find("\"paths\": {") {
            endpoints = self.extract_endpoints_from_openapi(&context[paths_start..]);
        }

        ((api_title, base_url), endpoints)
    }
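
    // The markers above assume a context shaped roughly like:
    //
    //   API Title: My API
    //   Base URLs: Production: localhost:8000, Staging: ...
    //   ...
    //   "paths": { ... }        <- embedded OpenAPI JSON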

    // Extract endpoints from the OpenAPI JSON embedded in the context
    fn extract_endpoints_from_openapi(&self, paths_section: &str) -> Vec<CompactEndpoint> {
        let mut endpoints = Vec::new();

        // Known endpoints with their query/path params and body fields
        // ("*" marks a required field). This is a hardcoded lookup table
        // matched against the paths section, not a full OpenAPI parse.
        let endpoint_definitions = [
            ("/api/health", "GET", "Health check", vec![], vec![]),
            ("/api/users", "GET", "List users", vec!["page", "limit", "status", "search"], vec![]),
            ("/api/users", "POST", "Create user", vec![], vec!["name*", "email*", "age"]),
            ("/api/users/{id}", "GET", "Get user", vec!["id*"], vec![]),
            ("/api/users/{id}", "PUT", "Update user", vec!["id*"], vec!["name", "email", "age", "status"]),
            ("/api/users/{id}", "DELETE", "Delete user", vec!["id*"], vec![]),
            ("/api/products", "GET", "List products", vec![], vec![]),
            ("/api/products", "POST", "Create product", vec![], vec!["name*", "price*", "category*"]),
        ];

        for (path, method, desc, params, body_fields) in endpoint_definitions.iter() {
            // Heuristic substring check: both the path literal and the
            // lowercased method key must appear somewhere in the paths JSON
            if paths_section.contains(&format!("\"{path}\"")) && paths_section.contains(&format!("\"{}\": {{", method.to_lowercase())) {
                endpoints.push(CompactEndpoint {
                    path: path.to_string(),
                    method: method.to_string(),
                    description: desc.to_string(),
                    params: params.iter().map(|s| s.to_string()).collect(),
                    body_fields: body_fields.iter().map(|s| s.to_string()).collect(),
                });
            }
        }

        endpoints
    }
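
    // Note: because the two checks are independent substring searches, a spec
    // that documents GET /api/users and POST /api/products will also match
    // POST /api/users; the relevance scoring below usually buries such
    // false positives.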

    // Score endpoints by relevance to the user's question
    fn find_relevant_endpoints<'a>(&self, endpoints: &'a [CompactEndpoint], user_message: &str) -> Vec<&'a CompactEndpoint> {
        let message = user_message.to_lowercase();
        let mut scored_endpoints = Vec::new();

        for endpoint in endpoints {
            let mut score = 0;

            // Score on path keywords
            let path_lower = endpoint.path.to_lowercase();
            if message.contains("user") && path_lower.contains("user") { score += 10; }
            if message.contains("product") && path_lower.contains("product") { score += 10; }
            if message.contains("health") && path_lower.contains("health") { score += 10; }

            // Score on HTTP method
            if message.contains("create") && endpoint.method == "POST" { score += 8; }
            if message.contains("update") && endpoint.method == "PUT" { score += 8; }
            if message.contains("delete") && endpoint.method == "DELETE" { score += 8; }
            if (message.contains("get") || message.contains("list")) && endpoint.method == "GET" {
                score += 8;
            }

            // Score on description words
            let desc_lower = endpoint.description.to_lowercase();
            for word in message.split_whitespace() {
                if desc_lower.contains(word) { score += 3; }
            }

            scored_endpoints.push((endpoint, score));
        }

        // Sort by score, highest first
        scored_endpoints.sort_by(|a, b| b.1.cmp(&a.1));

        // If nothing scored above zero, fall back to the first few endpoints
        if scored_endpoints.is_empty() || scored_endpoints[0].1 == 0 {
            return endpoints.iter().take(5).collect();
        }

        scored_endpoints.iter().map(|(endpoint, _)| *endpoint).collect()
    }
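
    // For example, "how do I create a user?" gives POST /api/users +10 for the
    // "user" path keyword, +8 for the POST method, and +3 per description word
    // that matches, so it outranks the GET endpoints for the same path.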

    // Format an endpoint compactly but informatively
    fn format_compact_endpoint(&self, endpoint: &CompactEndpoint) -> String {
        let mut result = format!("{} {}", endpoint.method, endpoint.path);

        // Request-body fields for POST/PUT, e.g. " {name*,email*,age}"
        if !endpoint.body_fields.is_empty() && (endpoint.method == "POST" || endpoint.method == "PUT") {
            result.push_str(" {");
            result.push_str(&endpoint.body_fields.join(","));
            result.push('}');
        }

        // Query or path parameters, e.g. " ?page,limit" or " ?id*"
        if !endpoint.params.is_empty() {
            result.push_str(" ?");
            result.push_str(&endpoint.params.join(","));
        }

        result.push('\n');
        result
    }
}

// Compact endpoint representation; a trailing "*" on a param or body field
// marks it as required
#[derive(Debug)]
struct CompactEndpoint {
    path: String,
    method: String,
    description: String,
    params: Vec<String>,
    body_fields: Vec<String>,
}
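
// For instance, the POST /api/users table entry above becomes:
//   CompactEndpoint {
//       path: "/api/users".into(),
//       method: "POST".into(),
//       description: "Create user".into(),
//       params: vec![],
//       body_fields: vec!["name*".into(), "email*".into(), "age".into()],
//   }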

#[async_trait::async_trait]
impl Client for OpenAIClient {
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
        let client = reqwest::Client::new();

        // Build the messages for the OpenAI API: a system message (the
        // Laravel-style compact context when one is available, otherwise a
        // minimal instruction), followed by the user message
        let mut messages = Vec::new();

        let system_content = match request.context.as_deref() {
            Some(context) if !context.is_empty() => {
                self.optimize_api_context(context, &request.message)
            }
            _ => "API assistant. Use only documented endpoints.".to_string(),
        };
        messages.push(json!({
            "role": "system",
            "content": system_content
        }));

        messages.push(json!({
            "role": "user",
            "content": request.message
        }));

        // Prepare the request body
        let mut body = json!({
            "model": self.config.features.model,
            "messages": messages
        });

        // Temperature: GPT-5 models (nano, mini, standard) only support the
        // integer temperature 1; other models take a float, so the integer
        // config value (stored in tenths) is converted
        if self.config.features.model.starts_with("gpt-5") {
            body["temperature"] = json!(1);
        } else {
            body["temperature"] = json!(self.config.features.temperature as f64 / 10.0);
        }

        // Token limits: GPT-5 models need headroom for reasoning tokens, so
        // enforce a floor of 2000 completion tokens
        if self.config.features.model.starts_with("gpt-5") {
            let completion_tokens = if self.config.features.max_completion_tokens > 0 {
                std::cmp::max(self.config.features.max_completion_tokens, 2000)
            } else {
                2000
            };
            body["max_completion_tokens"] = json!(completion_tokens);
        } else {
            // Other models use the configured limits as-is
            if self.config.features.max_completion_tokens > 0 {
                body["max_completion_tokens"] = json!(self.config.features.max_completion_tokens);
            } else if self.config.features.max_tokens > 0 {
                body["max_tokens"] = json!(self.config.features.max_tokens);
            }
        }
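
        // Illustrative final body for a non-GPT-5 model (example values only):
        // {
        //   "model": "gpt-4o-mini",
        //   "messages": [
        //     {"role": "system", "content": "API: My API | localhost:8000 ..."},
        //     {"role": "user", "content": "How do I create a user?"}
        //   ],
        //   "temperature": 0.7,
        //   "max_completion_tokens": 1024
        // }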

        // Make the API call to OpenAI
        let response = client
            .post("https://api.openai.com/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to send request to OpenAI: {}", e))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Ok(ChatResponse {
                response: "".to_string(),
                provider: "openai".to_string(),
                model: self.config.features.model.clone(),
                tokens_used: 0,
                error: format!("OpenAI API error: {}", error_text),
            });
        }

        // Parse the response
        let response_json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to parse OpenAI response: {}", e))?;

        // Extract the content from the response
        let choices = response_json
            .get("choices")
            .and_then(|c| c.as_array())
            .ok_or_else(|| anyhow::anyhow!("No choices in OpenAI response"))?;

        if choices.is_empty() {
            return Ok(ChatResponse {
                response: "".to_string(),
                provider: "openai".to_string(),
                model: self.config.features.model.clone(),
                tokens_used: 0,
                error: "No response choices returned from OpenAI".to_string(),
            });
        }

        let choice = &choices[0];
        let message = choice.get("message");
        let finish_reason = choice.get("finish_reason");

        let content = message
            .and_then(|m| m.get("content"))
            .and_then(|c| c.as_str())
            .unwrap_or("")
            .to_string();

        // Extract token usage if available (done before any early return)
        let tokens_used = response_json
            .get("usage")
            .and_then(|u| u.get("total_tokens"))
            .and_then(|t| t.as_i64())
            .unwrap_or(0) as i32;

        // Extract the model actually used
        let model_used = response_json
            .get("model")
            .and_then(|m| m.as_str())
            .unwrap_or(&self.config.features.model)
            .to_string();

        // Detect responses cut off by the token limit: GPT-5 models can spend
        // the whole budget on reasoning tokens and return empty content, so
        // substitute a helpful message instead of an empty reply
        if let Some(reason) = finish_reason.and_then(|r| r.as_str()) {
            if reason == "length" && content.is_empty() && self.config.features.model.starts_with("gpt-5") {
                return Ok(ChatResponse {
                    response: "I understand your question, but my response was truncated due to token limits. Could you please ask a more specific question about the API?".to_string(),
                    provider: "openai".to_string(),
                    model: model_used,
                    tokens_used,
                    error: String::new(),
                });
            }
        }

        Ok(ChatResponse {
            response: content,
            provider: "openai".to_string(),
            model: model_used,
            tokens_used,
            error: String::new(),
        })
    }

    fn get_provider(&self) -> &str {
        "openai"
    }

    fn get_model(&self) -> &str {
        &self.config.features.model
    }
}
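
// Hypothetical usage sketch (AIConfig construction elided; its fields live in
// crate::ai::types and are not shown here):
//
//   let client = OpenAIClient::new(&config)?;
//   let response = client.chat(ChatRequest {
//       message: "How do I create a user?".to_string(),
//       context: Some(docs_context),
//       // ...any remaining ChatRequest fields
//   }).await?;
//   println!("{}", response.response);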

impl GeminiClient {
    pub fn new(config: &AIConfig) -> Result<Box<dyn Client>> {
        Ok(Box::new(Self {
            config: config.clone(),
        }))
    }
}

#[async_trait::async_trait]
impl Client for GeminiClient {
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
        // TODO: Implement the actual Gemini API call; this placeholder echoes
        // the message back with a fixed token count
        Ok(ChatResponse {
            response: format!("Gemini response to: {}", request.message),
            provider: "gemini".to_string(),
            model: self.config.features.model.clone(),
            tokens_used: 100,
            error: String::new(),
        })
    }

    fn get_provider(&self) -> &str {
        "gemini"
    }

    fn get_model(&self) -> &str {
        &self.config.features.model
    }
}

impl OpenRouterClient {
    pub fn new(config: &AIConfig) -> Result<Box<dyn Client>> {
        if config.api_key.is_empty() {
            return Err(anyhow::anyhow!("OpenRouter API key is required"));
        }
        Ok(Box::new(Self {
            config: config.clone(),
        }))
    }

    // Build the system prompt; unlike the OpenAI client, this passes the raw
    // context through without the compact optimization
    fn build_system_prompt(&self, request: &ChatRequest) -> String {
        let mut prompt = String::from("You are a helpful AI assistant that provides information about APIs. ");

        if let Some(ref context) = request.context {
            if !context.is_empty() {
                prompt.push_str("Here's the API documentation context:\n\n");
                prompt.push_str(context);
                prompt.push_str("\n\nPlease help the user understand this API based on the provided documentation.");
            }
        }

        prompt
    }
}
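
// The resulting system prompt looks like (context shown as a placeholder):
//
//   You are a helpful AI assistant that provides information about APIs. Here's the API documentation context:
//
//   <context verbatim>
//
//   Please help the user understand this API based on the provided documentation.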

#[async_trait::async_trait]
impl Client for OpenRouterClient {
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
        let client = reqwest::Client::new();

        // Build messages for the OpenAI-compatible API
        let system_prompt = self.build_system_prompt(&request);
        let messages = vec![
            json!({
                "role": "system",
                "content": system_prompt
            }),
            json!({
                "role": "user",
                "content": request.message
            })
        ];

        // Prepare the request body
        let mut body = json!({
            "model": self.config.features.model,
            "messages": messages
        });

        // Add optional parameters
        if self.config.features.max_tokens > 0 {
            body["max_tokens"] = json!(self.config.features.max_tokens);
        }
        if self.config.features.max_completion_tokens > 0 {
            body["max_completion_tokens"] = json!(self.config.features.max_completion_tokens);
        }
        if self.config.features.temperature > 0 {
            // Convert the integer config value (tenths) to a float, matching
            // the OpenAI client above; sending the raw integer would exceed
            // the valid temperature range
            body["temperature"] = json!(self.config.features.temperature as f64 / 10.0);
        }

        // Make the API call to OpenRouter
        let response = client
            .post("https://openrouter.ai/api/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .header("HTTP-Referer", "https://bytedocs.rs") // Optional: app attribution on OpenRouter
            .header("X-Title", "ByteDocs") // Optional: app attribution on OpenRouter
            .json(&body)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to send request to OpenRouter: {}", e))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Ok(ChatResponse {
                response: "".to_string(),
                provider: "openrouter".to_string(),
                model: self.config.features.model.clone(),
                tokens_used: 0,
                error: format!("OpenRouter API error: {}", error_text),
            });
        }

        // Parse the response
        let response_json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to parse OpenRouter response: {}", e))?;

        // Extract the content from the response
        let choices = response_json
            .get("choices")
            .and_then(|c| c.as_array())
            .ok_or_else(|| anyhow::anyhow!("No choices in OpenRouter response"))?;

        if choices.is_empty() {
            return Ok(ChatResponse {
                response: "".to_string(),
                provider: "openrouter".to_string(),
                model: self.config.features.model.clone(),
                tokens_used: 0,
                error: "No response choices returned from OpenRouter".to_string(),
            });
        }

        let content = choices[0]
            .get("message")
            .and_then(|m| m.get("content"))
            .and_then(|c| c.as_str())
            .unwrap_or("")
            .to_string();

        // Extract token usage if available
        let tokens_used = response_json
            .get("usage")
            .and_then(|u| u.get("total_tokens"))
            .and_then(|t| t.as_i64())
            .unwrap_or(0) as i32;

        // Extract the model actually used
        let model_used = response_json
            .get("model")
            .and_then(|m| m.as_str())
            .unwrap_or(&self.config.features.model)
            .to_string();

        Ok(ChatResponse {
            response: content,
            provider: "openrouter".to_string(),
            model: model_used,
            tokens_used,
            error: String::new(),
        })
    }

    fn get_provider(&self) -> &str {
        "openrouter"
    }

    fn get_model(&self) -> &str {
        &self.config.features.model
    }
}

// Register all client factories
pub fn init_client_factories() {
    register_client_factory("openai", |config| OpenAIClient::new(config));
    register_client_factory("gemini", |config| GeminiClient::new(config));
    register_client_factory("openrouter", |config| OpenRouterClient::new(config));
}
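
// Call once at startup, before resolving any provider. A sketch, assuming the
// crate/module path matches the file location:
//
//   bytedocs_rs::ai::client::init_client_factories();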