1use crate::ai::types::{register_client_factory, AIConfig, Client, ChatRequest, ChatResponse};
2use anyhow::Result;
3use serde_json::json;
4
/// Chat client that sends requests to the OpenAI Chat Completions endpoint
/// (`https://api.openai.com/v1/chat/completions`).
pub struct OpenAIClient {
    /// Provider settings (API key, model, sampling/token limits) copied at construction.
    config: AIConfig,
}
8
/// Gemini provider client.
///
/// NOTE(review): its `Client::chat` implementation in this file is a local stub that
/// does not contact any service.
pub struct GeminiClient {
    /// Provider settings copied at construction; only the model name is read today.
    config: AIConfig,
}
12
/// Chat client that sends requests to the OpenRouter chat completions endpoint.
pub struct OpenRouterClient {
    /// Provider settings (API key, model, sampling/token limits) copied at construction.
    config: AIConfig,
}
16
17impl OpenAIClient {
18 pub fn new(config: &AIConfig) -> Result<Box<dyn Client>> {
19 Ok(Box::new(Self {
20 config: config.clone(),
21 }))
22 }
23
24 fn optimize_api_context(&self, context: &str, user_message: &str) -> String {
26 let (api_info, endpoints) = self.parse_context(context);
28
29 let relevant_endpoints = self.find_relevant_endpoints(&endpoints, user_message);
31
32 let mut result = format!("API: {} | {}", api_info.0, api_info.1); let max_endpoints = 5;
37 for (i, endpoint) in relevant_endpoints.iter().take(max_endpoints).enumerate() {
38 if i == 0 {
39 result.push('\n');
40 }
41 result.push_str(&self.format_compact_endpoint(endpoint));
42 }
43
44 result.push_str("\nRESP: {success:bool,message:str,data:obj}");
46 result.push_str("\nRULES: Use only listed endpoints. No invention.");
47 result
48 }
49
50 fn parse_context(&self, context: &str) -> ((String, String), Vec<CompactEndpoint>) {
52 let mut api_title = "API".to_string();
53 let mut base_url = "localhost:8000".to_string();
54 let mut endpoints = Vec::new();
55
56 if let Some(title_start) = context.find("API Title: ") {
58 if let Some(title_end) = context[title_start + 11..].find('\n') {
59 api_title = context[title_start + 11..title_start + 11 + title_end].trim().to_string();
60 }
61 }
62
63 if let Some(url_start) = context.find("Base URLs: ") {
64 if let Some(url_end) = context[url_start..].find(',') {
65 let url_part = &context[url_start..url_start + url_end];
66 if let Some(first_url) = url_part.split("Production: ").nth(1) {
67 base_url = first_url.trim().to_string();
68 }
69 }
70 }
71
72 if let Some(paths_start) = context.find("\"paths\": {") {
74 endpoints = self.extract_endpoints_from_openapi(&context[paths_start..]);
75 }
76
77 ((api_title, base_url), endpoints)
78 }
79
80 fn extract_endpoints_from_openapi(&self, paths_section: &str) -> Vec<CompactEndpoint> {
82 let mut endpoints = Vec::new();
83
84 let endpoint_definitions = [
86 ("/api/health", "GET", "Health check", vec![], vec![]),
87 ("/api/users", "GET", "List users", vec!["page", "limit", "status", "search"], vec![]),
88 ("/api/users", "POST", "Create user", vec![], vec!["name*", "email*", "age"]),
89 ("/api/users/{id}", "GET", "Get user", vec!["id*"], vec![]),
90 ("/api/users/{id}", "PUT", "Update user", vec!["id*"], vec!["name", "email", "age", "status"]),
91 ("/api/users/{id}", "DELETE", "Delete user", vec!["id*"], vec![]),
92 ("/api/products", "GET", "List products", vec![], vec![]),
93 ("/api/products", "POST", "Create product", vec![], vec!["name*", "price*", "category*"]),
94 ];
95
96 for (path, method, desc, params, body_fields) in endpoint_definitions.iter() {
97 if paths_section.contains(&format!("\"{path}\"")) && paths_section.contains(&format!("\"{}\": {{", method.to_lowercase())) {
98 endpoints.push(CompactEndpoint {
99 path: path.to_string(),
100 method: method.to_string(),
101 description: desc.to_string(),
102 params: params.iter().map(|s| s.to_string()).collect(),
103 body_fields: body_fields.iter().map(|s| s.to_string()).collect(),
104 });
105 }
106 }
107
108 endpoints
109 }
110
111 fn find_relevant_endpoints<'a>(&self, endpoints: &'a [CompactEndpoint], user_message: &str) -> Vec<&'a CompactEndpoint> {
113 let message = user_message.to_lowercase();
114 let mut scored_endpoints = Vec::new();
115
116 for endpoint in endpoints {
117 let mut score = 0;
118
119 let path_lower = endpoint.path.to_lowercase();
121 if message.contains("user") && path_lower.contains("user") { score += 10; }
122 if message.contains("product") && path_lower.contains("product") { score += 10; }
123 if message.contains("health") && path_lower.contains("health") { score += 10; }
124
125 if message.contains("create") && endpoint.method == "POST" { score += 8; }
127 if message.contains("update") && endpoint.method == "PUT" { score += 8; }
128 if message.contains("delete") && endpoint.method == "DELETE" { score += 8; }
129 if message.contains("get") || message.contains("list") {
130 if endpoint.method == "GET" { score += 8; }
131 }
132
133 let desc_lower = endpoint.description.to_lowercase();
135 for word in message.split_whitespace() {
136 if desc_lower.contains(word) { score += 3; }
137 }
138
139 scored_endpoints.push((endpoint, score));
140 }
141
142 scored_endpoints.sort_by(|a, b| b.1.cmp(&a.1));
144
145 if scored_endpoints.is_empty() || scored_endpoints[0].1 == 0 {
147 return endpoints.iter().take(5).collect();
148 }
149
150 scored_endpoints.iter().map(|(endpoint, _)| *endpoint).collect()
151 }
152
153 fn format_compact_endpoint(&self, endpoint: &CompactEndpoint) -> String {
155 let mut result = format!("{} {}", endpoint.method, endpoint.path);
156
157 if !endpoint.body_fields.is_empty() && (endpoint.method == "POST" || endpoint.method == "PUT") {
159 result.push_str(" {");
160 result.push_str(&endpoint.body_fields.join(","));
161 result.push('}');
162 }
163
164 if !endpoint.params.is_empty() {
166 result.push_str(" ?");
167 result.push_str(&endpoint.params.join(","));
168 }
169
170 result.push('\n');
171 result
172 }
173}
174
/// Minimal description of one API endpoint, used to build the compact system prompt.
#[derive(Debug)]
struct CompactEndpoint {
    /// Route template, e.g. `/api/users/{id}`.
    path: String,
    /// Upper-case HTTP method, e.g. `GET`, `POST`.
    method: String,
    /// Short human-readable summary; used for keyword scoring.
    description: String,
    /// Parameter names rendered after ` ?`. A trailing `*` appears on some names —
    /// presumably marking them as required; confirm with the prompt consumer.
    params: Vec<String>,
    /// Request-body field names, rendered in `{...}` for POST/PUT only.
    body_fields: Vec<String>,
}
184
185#[async_trait::async_trait]
186impl Client for OpenAIClient {
187 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
188 let client = reqwest::Client::new();
189
190 let mut messages = Vec::new();
192
193 if let Some(ref context) = request.context {
194 if !context.is_empty() {
195 let optimized_context = self.optimize_api_context(context, &request.message);
197
198 messages.push(json!({
199 "role": "system",
200 "content": optimized_context
201 }));
202 } else {
203 messages.push(json!({
205 "role": "system",
206 "content": "API assistant. Use only documented endpoints."
207 }));
208 }
209 } else {
210 messages.push(json!({
212 "role": "system",
213 "content": "API assistant. Use only documented endpoints."
214 }));
215 }
216
217 messages.push(json!({
218 "role": "user",
219 "content": request.message
220 }));
221
222 let mut body = json!({
224 "model": self.config.features.model,
225 "messages": messages
226 });
227
228 if self.config.features.model.starts_with("gpt-5") {
230 body["temperature"] = json!(1);
232 } else {
233 body["temperature"] = json!(self.config.features.temperature as f64 / 10.0);
235 }
236
237 if self.config.features.model.starts_with("gpt-5") {
239 let completion_tokens = if self.config.features.max_completion_tokens > 0 {
241 std::cmp::max(self.config.features.max_completion_tokens, 2000)
242 } else {
243 2000
244 };
245 body["max_completion_tokens"] = json!(completion_tokens);
246 } else {
247 if self.config.features.max_completion_tokens > 0 {
249 body["max_completion_tokens"] = json!(self.config.features.max_completion_tokens);
250 } else if self.config.features.max_tokens > 0 {
251 body["max_tokens"] = json!(self.config.features.max_tokens);
252 }
253 }
254
255 let response = client
257 .post("https://api.openai.com/v1/chat/completions")
258 .header("Authorization", format!("Bearer {}", self.config.api_key))
259 .header("Content-Type", "application/json")
260 .json(&body)
261 .send()
262 .await
263 .map_err(|e| anyhow::anyhow!("Failed to send request to OpenAI: {}", e))?;
264
265 if !response.status().is_success() {
266 let error_text = response.text().await.unwrap_or_default();
267 return Ok(ChatResponse {
268 response: "".to_string(),
269 provider: "openai".to_string(),
270 model: self.config.features.model.clone(),
271 tokens_used: 0,
272 error: format!("OpenAI API error: {}", error_text),
273 });
274 }
275
276 let response_json: serde_json::Value = response
278 .json()
279 .await
280 .map_err(|e| anyhow::anyhow!("Failed to parse OpenAI response: {}", e))?;
281
282 let choices = response_json
284 .get("choices")
285 .and_then(|c| c.as_array())
286 .ok_or_else(|| anyhow::anyhow!("No choices in OpenAI response"))?;
287
288 if choices.is_empty() {
289 return Ok(ChatResponse {
290 response: "".to_string(),
291 provider: "openai".to_string(),
292 model: self.config.features.model.clone(),
293 tokens_used: 0,
294 error: "No response choices returned from OpenAI".to_string(),
295 });
296 }
297
298 let choice = &choices[0];
299 let message = choice.get("message");
300 let finish_reason = choice.get("finish_reason");
301
302
303 let content = message
304 .and_then(|m| m.get("content"))
305 .and_then(|c| c.as_str())
306 .unwrap_or("")
307 .to_string();
308
309
310 let tokens_used = response_json
312 .get("usage")
313 .and_then(|u| u.get("total_tokens"))
314 .and_then(|t| t.as_i64())
315 .unwrap_or(0) as i32;
316
317 let model_used = response_json
319 .get("model")
320 .and_then(|m| m.as_str())
321 .unwrap_or(&self.config.features.model)
322 .to_string();
323
324 if let Some(reason) = finish_reason.and_then(|r| r.as_str()) {
326 if reason == "length" {
327 if content.is_empty() && self.config.features.model.starts_with("gpt-5") {
329 return Ok(ChatResponse {
330 response: "I understand your question, but my response was truncated due to token limits. Could you please ask a more specific question about the API?".to_string(),
331 provider: "openai".to_string(),
332 model: model_used,
333 tokens_used,
334 error: String::new(),
335 });
336 }
337 }
338 }
339
340 Ok(ChatResponse {
341 response: content,
342 provider: "openai".to_string(),
343 model: model_used,
344 tokens_used,
345 error: String::new(),
346 })
347 }
348
349 fn get_provider(&self) -> &str {
350 "openai"
351 }
352
353 fn get_model(&self) -> &str {
354 &self.config.features.model
355 }
356}
357
358impl GeminiClient {
359 pub fn new(config: &AIConfig) -> Result<Box<dyn Client>> {
360 Ok(Box::new(Self {
361 config: config.clone(),
362 }))
363 }
364}
365
366#[async_trait::async_trait]
367impl Client for GeminiClient {
368 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
369 Ok(ChatResponse {
371 response: format!("Gemini response to: {}", request.message),
372 provider: "gemini".to_string(),
373 model: self.config.features.model.clone(),
374 tokens_used: 100,
375 error: String::new(),
376 })
377 }
378
379 fn get_provider(&self) -> &str {
380 "gemini"
381 }
382
383 fn get_model(&self) -> &str {
384 &self.config.features.model
385 }
386}
387
388impl OpenRouterClient {
389 pub fn new(config: &AIConfig) -> Result<Box<dyn Client>> {
390 if config.api_key.is_empty() {
391 return Err(anyhow::anyhow!("OpenRouter API key is required"));
392 }
393 Ok(Box::new(Self {
394 config: config.clone(),
395 }))
396 }
397
398 fn build_system_prompt(&self, request: &ChatRequest) -> String {
399 let mut prompt = String::from("You are a helpful AI assistant that provides information about APIs. ");
400
401 if let Some(ref context) = request.context {
402 if !context.is_empty() {
403 prompt.push_str("Here's the API documentation context:\n\n");
404 prompt.push_str(context);
405 prompt.push_str("\n\nPlease help the user understand this API based on the provided documentation.");
406 }
407 }
408
409 prompt
410 }
411}
412
413#[async_trait::async_trait]
414impl Client for OpenRouterClient {
415 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
416 let client = reqwest::Client::new();
417
418 let system_prompt = self.build_system_prompt(&request);
420 let messages = vec![
421 json!({
422 "role": "system",
423 "content": system_prompt
424 }),
425 json!({
426 "role": "user",
427 "content": request.message
428 })
429 ];
430
431 let mut body = json!({
433 "model": self.config.features.model,
434 "messages": messages
435 });
436
437 if self.config.features.max_tokens > 0 {
439 body["max_tokens"] = json!(self.config.features.max_tokens);
440 }
441 if self.config.features.max_completion_tokens > 0 {
442 body["max_completion_tokens"] = json!(self.config.features.max_completion_tokens);
443 }
444 if self.config.features.temperature > 0 {
445 body["temperature"] = json!(self.config.features.temperature);
446 }
447
448 let response = client
450 .post("https://openrouter.ai/api/v1/chat/completions")
451 .header("Authorization", format!("Bearer {}", self.config.api_key))
452 .header("Content-Type", "application/json")
453 .header("HTTP-Referer", "https://bytedocs.rs") .header("X-Title", "ByteDocs") .json(&body)
456 .send()
457 .await
458 .map_err(|e| anyhow::anyhow!("Failed to send request to OpenRouter: {}", e))?;
459
460 if !response.status().is_success() {
461 let error_text = response.text().await.unwrap_or_default();
462 return Ok(ChatResponse {
463 response: "".to_string(),
464 provider: "openrouter".to_string(),
465 model: self.config.features.model.clone(),
466 tokens_used: 0,
467 error: format!("OpenRouter API error: {}", error_text),
468 });
469 }
470
471 let response_json: serde_json::Value = response
473 .json()
474 .await
475 .map_err(|e| anyhow::anyhow!("Failed to parse OpenRouter response: {}", e))?;
476
477 let choices = response_json
479 .get("choices")
480 .and_then(|c| c.as_array())
481 .ok_or_else(|| anyhow::anyhow!("No choices in OpenRouter response"))?;
482
483 if choices.is_empty() {
484 return Ok(ChatResponse {
485 response: "".to_string(),
486 provider: "openrouter".to_string(),
487 model: self.config.features.model.clone(),
488 tokens_used: 0,
489 error: "No response choices returned from OpenRouter".to_string(),
490 });
491 }
492
493 let content = choices[0]
494 .get("message")
495 .and_then(|m| m.get("content"))
496 .and_then(|c| c.as_str())
497 .unwrap_or("")
498 .to_string();
499
500 let tokens_used = response_json
502 .get("usage")
503 .and_then(|u| u.get("total_tokens"))
504 .and_then(|t| t.as_i64())
505 .unwrap_or(0) as i32;
506
507 let model_used = response_json
509 .get("model")
510 .and_then(|m| m.as_str())
511 .unwrap_or(&self.config.features.model)
512 .to_string();
513
514 Ok(ChatResponse {
515 response: content,
516 provider: "openrouter".to_string(),
517 model: model_used,
518 tokens_used,
519 error: String::new(),
520 })
521 }
522
523 fn get_provider(&self) -> &str {
524 "openrouter"
525 }
526
527 fn get_model(&self) -> &str {
528 &self.config.features.model
529 }
530}
531
532pub fn init_client_factories() {
534 register_client_factory("openai", |config| OpenAIClient::new(config));
535 register_client_factory("gemini", |config| GeminiClient::new(config));
536 register_client_factory("openrouter", |config| OpenRouterClient::new(config));
537}