use anyhow::Result;
use futures_util::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;

use crate::template_processor::TemplateProcessor;

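/// Request body for an OpenAI-style `/chat/completions` call. Fields marked
/// `skip_serializing_if` are omitted from the JSON when unset; note that
/// `max_tokens` and `temperature` are always serialized (as `null` when `None`).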
#[derive(Debug, Serialize)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<Message>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

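/// Same payload as [`ChatRequest`] minus the `model` field, used when the
/// provider's chat path already embeds the model name (a `{model}` placeholder
/// in the URL), so repeating it in the body is unnecessary.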
#[derive(Debug, Serialize)]
pub struct ChatRequestWithoutModel {
    pub messages: Vec<Message>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

impl From<&ChatRequest> for ChatRequestWithoutModel {
    fn from(request: &ChatRequest) -> Self {
        Self {
            messages: request.messages.clone(),
            max_tokens: request.max_tokens,
            temperature: request.temperature,
            tools: request.tools.clone(),
            stream: request.stream,
        }
    }
}

#[derive(Debug, Serialize)]
pub struct EmbeddingRequest {
    pub model: String,
    pub input: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<String>,
}

#[derive(Debug, Serialize)]
pub struct ImageGenerationRequest {
    pub prompt: String,
    pub model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quality: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub style: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
}

#[derive(Debug, Serialize)]
pub struct AudioTranscriptionRequest {
    pub file: String,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
}

#[derive(Debug, Deserialize)]
pub struct AudioTranscriptionResponse {
    pub text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[allow(dead_code)]
    pub language: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[allow(dead_code)]
    pub duration: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[allow(dead_code)]
    pub segments: Option<Vec<TranscriptionSegment>>,
}

#[derive(Debug, Deserialize)]
pub struct TranscriptionSegment {
    #[allow(dead_code)]
    pub id: i32,
    #[allow(dead_code)]
    pub start: f32,
    #[allow(dead_code)]
    pub end: f32,
    #[allow(dead_code)]
    pub text: String,
}

#[derive(Debug, Serialize)]
pub struct AudioSpeechRequest {
    pub model: String,
    pub input: String,
    pub voice: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed: Option<f32>,
}

#[derive(Debug, Deserialize)]
pub struct ImageGenerationResponse {
    pub data: Vec<ImageData>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ImageData {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub b64_json: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revised_prompt: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub data: Vec<EmbeddingData>,
    pub usage: EmbeddingUsage,
}

#[derive(Debug, Deserialize, Clone)]
pub struct EmbeddingData {
    pub embedding: Vec<f64>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct EmbeddingUsage {
    pub total_tokens: u32,
}

#[derive(Debug, Serialize, Clone)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: Function,
}

#[derive(Debug, Serialize, Clone)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

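/// A chat message. `content_type` is flattened and untagged, so a plain text
/// message serializes flat (`{"role": "user", "content": "..."}`), while a
/// multimodal message serializes `content` as an array of typed parts.
///
/// A minimal usage sketch (doc-test ignored; illustrative only):
///
/// ```ignore
/// let msg = Message::user("Hello".to_string());
/// // Serializes to: {"role":"user","content":"Hello"}
/// let json = serde_json::to_string(&msg).unwrap();
/// ```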
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    pub role: String,
    #[serde(flatten)]
    pub content_type: MessageContent,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum MessageContent {
    Text { content: Option<String> },
    Multimodal { content: Vec<ContentPart> },
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

impl Message {
    pub fn user(content: String) -> Self {
        Self {
            role: "user".to_string(),
            content_type: MessageContent::Text {
                content: Some(content),
            },
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[allow(dead_code)]
    pub fn user_with_image(text: String, image_data: String, detail: Option<String>) -> Self {
        Self {
            role: "user".to_string(),
            content_type: MessageContent::Multimodal {
                content: vec![
                    ContentPart::Text { text },
                    ContentPart::ImageUrl {
                        image_url: ImageUrl {
                            url: image_data,
                            detail,
                        },
                    },
                ],
            },
            tool_calls: None,
            tool_call_id: None,
        }
    }

    pub fn assistant(content: String) -> Self {
        Self {
            role: "assistant".to_string(),
            content_type: MessageContent::Text {
                content: Some(content),
            },
            tool_calls: None,
            tool_call_id: None,
        }
    }

    pub fn assistant_with_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
        Self {
            role: "assistant".to_string(),
            content_type: MessageContent::Text { content: None },
            tool_calls: Some(tool_calls),
            tool_call_id: None,
        }
    }

    pub fn tool_result(tool_call_id: String, content: String) -> Self {
        Self {
            role: "tool".to_string(),
            content_type: MessageContent::Text {
                content: Some(content),
            },
            tool_calls: None,
            tool_call_id: Some(tool_call_id),
        }
    }

    /// Returns the message's text content, or the first text part of a
    /// multimodal message.
    pub fn get_text_content(&self) -> Option<&String> {
        match &self.content_type {
            MessageContent::Text { content } => content.as_ref(),
            MessageContent::Multimodal { content } => {
                content.iter().find_map(|part| match part {
                    ContentPart::Text { text } => Some(text),
                    _ => None,
                })
            }
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct ChatResponse {
    pub choices: Vec<Choice>,
}

#[derive(Debug, Deserialize)]
pub struct Choice {
    pub message: ResponseMessage,
}

#[derive(Debug, Deserialize)]
pub struct ResponseMessage {
    #[allow(dead_code)]
    pub role: String,
    pub content: Option<String>,
    pub tool_calls: Option<Vec<ToolCall>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: FunctionCall,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}

#[derive(Debug, Deserialize)]
pub struct ModelsResponse {
    #[serde(alias = "models")]
    pub data: Vec<Model>,
}

#[derive(Debug, Deserialize)]
pub struct Provider {
    pub provider: String,
    #[allow(dead_code)]
    pub status: String,
    #[serde(default)]
    #[allow(dead_code)]
    pub supports_tools: bool,
    #[serde(default)]
    #[allow(dead_code)]
    pub supports_structured_output: bool,
}

#[derive(Debug, Deserialize)]
pub struct Model {
    pub id: String,
    #[serde(default = "default_object_type")]
    pub object: String,
    #[serde(default)]
    pub providers: Vec<Provider>,
}

fn default_object_type() -> String {
    "model".to_string()
}

#[derive(Debug, Deserialize)]
pub struct TokenResponse {
    pub token: String,
    pub expires_at: i64,
}

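/// HTTP client for OpenAI-compatible providers. Keeps two `reqwest` clients:
/// `client` with a 60-second timeout for regular requests and
/// `streaming_client` with a 300-second timeout for streaming chat. Optional
/// provider configuration and request/response templates allow per-provider
/// URL layouts and payload shapes.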
pub struct OpenAIClient {
    client: Client,
    streaming_client: Client,
    base_url: String,
    api_key: String,
    models_path: String,
    chat_path: String,
    custom_headers: std::collections::HashMap<String, String>,
    provider_config: Option<crate::config::ProviderConfig>,
    template_processor: Option<TemplateProcessor>,
}

impl OpenAIClient {
    pub fn create_http_client(
        base_url: String,
        api_key: String,
        models_path: String,
        chat_path: String,
        custom_headers: std::collections::HashMap<String, String>,
        provider_config: Option<crate::config::ProviderConfig>,
    ) -> Result<Self> {
        let default_headers = Self::create_default_headers();

        let client = Self::build_http_client(default_headers.clone(), Duration::from_secs(60))?;

        let streaming_client = Self::build_http_client(default_headers, Duration::from_secs(300))?;

        let template_processor = provider_config
            .as_ref()
            .and_then(|config| Self::create_template_processor(config));

        Ok(Self {
            client,
            streaming_client,
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key,
            models_path,
            chat_path,
            custom_headers,
            provider_config,
            template_processor,
        })
    }

    pub fn new_with_headers(
        base_url: String,
        api_key: String,
        models_path: String,
        chat_path: String,
        custom_headers: std::collections::HashMap<String, String>,
    ) -> Self {
        Self::create_http_client(
            base_url,
            api_key,
            models_path,
            chat_path,
            custom_headers,
            None,
        )
        .expect("Failed to create OpenAI client")
    }

    pub fn new_with_provider_config(
        base_url: String,
        api_key: String,
        models_path: String,
        chat_path: String,
        custom_headers: std::collections::HashMap<String, String>,
        provider_config: crate::config::ProviderConfig,
    ) -> Self {
        Self::create_http_client(
            base_url,
            api_key,
            models_path,
            chat_path,
            custom_headers,
            Some(provider_config),
        )
        .expect("Failed to create OpenAI client with provider config")
    }

    fn create_default_headers() -> reqwest::header::HeaderMap {
        use reqwest::header::{HeaderName, HeaderValue};

        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            HeaderName::from_static("http-referer"),
            HeaderValue::from_static("https://lc.viwq.dev/"),
        );
        headers.insert(
            HeaderName::from_static("x-title"),
            HeaderValue::from_static("lc"),
        );
        headers
    }

    fn build_http_client(
        default_headers: reqwest::header::HeaderMap,
        timeout: Duration,
    ) -> Result<Client> {
        let mut builder = Client::builder()
            .pool_max_idle_per_host(10)
            .pool_idle_timeout(Duration::from_secs(90))
            .tcp_keepalive(Duration::from_secs(60))
            .timeout(timeout)
            .connect_timeout(Duration::from_secs(10))
            .user_agent(concat!(
                env!("CARGO_PKG_NAME"),
                "/",
                env!("CARGO_PKG_VERSION")
            ))
            .default_headers(default_headers);

        // Opt-in escape hatch for self-signed or otherwise untrusted certificates.
        if std::env::var("LC_DISABLE_TLS_VERIFY").is_ok() {
            builder = builder.danger_accept_invalid_certs(true);
        }

        builder
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))
    }

    fn create_template_processor(
        config: &crate::config::ProviderConfig,
    ) -> Option<TemplateProcessor> {
        let has_templates = config.chat_templates.is_some()
            || config.images_templates.is_some()
            || config.embeddings_templates.is_some()
            || config.models_templates.is_some()
            || config.speech_templates.is_some();

        if has_templates {
            match TemplateProcessor::new() {
                Ok(processor) => Some(processor),
                Err(e) => {
                    eprintln!("Warning: Failed to create template processor: {}", e);
                    None
                }
            }
        } else {
            None
        }
    }

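    /// Resolves the chat completion URL. A provider config, when present,
    /// decides the URL; otherwise a fully-qualified `chat_path` (one starting
    /// with "https://") is used as-is with `{model_name}`/`{model}` placeholders
    /// substituted, and a relative path is appended to `base_url`. For example,
    /// a hypothetical chat_path of "https://example.test/v1/{model}/chat" would
    /// have `{model}` replaced with the requested model id.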
    fn get_chat_url(&self, model: &str) -> String {
        if let Some(ref config) = self.provider_config {
            config.get_chat_url(model)
        } else if self.chat_path.starts_with("https://") {
            self.chat_path
                .replace("{model_name}", model)
                .replace("{model}", model)
        } else {
            format!("{}{}", self.base_url, self.chat_path)
        }
    }

    fn build_url(&self, endpoint_type: &str, model: &str, default_path: &str) -> String {
        match endpoint_type {
            "models" => format!("{}{}", self.base_url, self.models_path),
            "embeddings" => {
                if let Some(ref config) = self.provider_config {
                    config.get_embeddings_url(model)
                } else {
                    format!("{}/embeddings", self.base_url)
                }
            }
            "images" => {
                if let Some(ref config) = self.provider_config {
                    config.get_images_url(model)
                } else {
                    format!("{}/images/generations", self.base_url)
                }
            }
            "audio_transcriptions" => {
                if let Some(ref config) = self.provider_config {
                    format!(
                        "{}{}",
                        self.base_url,
                        config
                            .audio_path
                            .as_deref()
                            .unwrap_or("/audio/transcriptions")
                    )
                } else {
                    format!("{}/audio/transcriptions", self.base_url)
                }
            }
            "audio_speech" => {
                if let Some(ref config) = self.provider_config {
                    config.get_speech_url(model)
                } else {
                    format!("{}/audio/speech", self.base_url)
                }
            }
            _ => {
                format!("{}{}", self.base_url, default_path)
            }
        }
    }

    fn add_standard_headers(&self, mut req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        if !self.custom_headers.contains_key("Authorization")
            && !self.custom_headers.contains_key("authorization")
        {
            req = req.header("Authorization", format!("Bearer {}", self.api_key));
        }

        for (name, value) in &self.custom_headers {
            req = req.header(name, value);
        }

        req
    }

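    /// Sends a chat completion request and returns the assistant's text.
    /// A provider-specific request template is applied when configured; a chat
    /// path containing `{model}` causes the model field to be dropped from the
    /// body. The response is run through an optional response template before
    /// falling back to standard `ChatResponse` parsing, and tool calls are
    /// surfaced as a textual summary (execution happens in the chat module).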
    pub async fn chat(&self, request: &ChatRequest) -> Result<String> {
        let url = self.get_chat_url(&request.model);

        let mut req = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        if request.stream == Some(true) {
            req = req.header("Accept-Encoding", "identity");
        }

        if !self.custom_headers.contains_key("Authorization")
            && !self.custom_headers.contains_key("authorization")
        {
            req = req.header("Authorization", format!("Bearer {}", self.api_key));
        }

        for (name, value) in &self.custom_headers {
            req = req.header(name, value);
        }

        // Prefer a provider-specific request template when one is configured.
        let request_body = if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let template = config.get_endpoint_template("chat", &request.model);

                if let Some(template_str) = template {
                    let mut processor_clone = processor.clone();
                    match processor_clone.process_request(request, &template_str, &config.vars) {
                        Ok(json_value) => Some(json_value),
                        Err(e) => {
                            eprintln!("Warning: Failed to process request template: {}. Falling back to default.", e);
                            None
                        }
                    }
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        let response = if let Some(json_body) = request_body {
            req.json(&json_body).send().await?
        } else {
            // Providers whose chat path embeds `{model}` do not expect it in the body.
            let should_exclude_model = if let Some(ref config) = self.provider_config {
                config.chat_path.contains("{model}")
            } else {
                self.chat_path.contains("{model}")
            };

            if should_exclude_model {
                let request_without_model = ChatRequestWithoutModel::from(request);
                req.json(&request_without_model).send().await?
            } else {
                req.json(request).send().await?
            }
        };

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            anyhow::bail!("API request failed with status {}: {}", status, text);
        }

        let response_text = response.text().await?;

        // Try a provider-specific response template first.
        if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let template = config.get_endpoint_response_template("chat", &request.model);

                if let Some(template_str) = template {
                    if let Ok(response_json) =
                        serde_json::from_str::<serde_json::Value>(&response_text)
                    {
                        let mut processor_clone = processor.clone();
                        match processor_clone.process_response(&response_json, &template_str) {
                            Ok(extracted) => {
                                if let Some(content) =
                                    extracted.get("content").and_then(|v| v.as_str())
                                {
                                    return Ok(content.to_string());
                                } else if let Some(tool_calls) =
                                    extracted.get("tool_calls").and_then(|v| v.as_array())
                                {
                                    if !tool_calls.is_empty() {
                                        let mut response = String::new();
                                        response.push_str("🔧 **Tool Calls Made:**\n\n");
                                        response
                                            .push_str(&format!("Tool calls: {:?}\n\n", tool_calls));
                                        response.push_str("*Tool calls detected - execution handled by chat module*\n\n");
                                        return Ok(response);
                                    }
                                }
                            }
                            Err(e) => {
                                eprintln!("Warning: Failed to process response template: {}. Falling back to default parsing.", e);
                            }
                        }
                    }
                }
            }
        }

        // Default OpenAI-style parsing.
        if let Ok(chat_response) = serde_json::from_str::<ChatResponse>(&response_text) {
            if let Some(choice) = chat_response.choices.first() {
                if let Some(tool_calls) = &choice.message.tool_calls {
                    if !tool_calls.is_empty() {
                        let mut response = String::new();
                        response.push_str("🔧 **Tool Calls Made:**\n\n");

                        for tool_call in tool_calls {
                            response.push_str(&format!(
                                "**Function:** `{}`\n",
                                tool_call.function.name
                            ));
                            response.push_str(&format!(
                                "**Arguments:** `{}`\n\n",
                                tool_call.function.arguments
                            ));

                            response.push_str(
                                "*Tool calls detected - execution handled by chat module*\n\n",
                            );
                        }

                        return Ok(response);
                    }
                }

                if let Some(content) = &choice.message.content {
                    return Ok(content.clone());
                } else {
                    anyhow::bail!("No content or tool calls in response");
                }
            } else {
                anyhow::bail!("No response from API");
            }
        }

        anyhow::bail!("Failed to parse chat response. Response: {}", response_text);
    }

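    /// Lists available models. Three response shapes are accepted: the
    /// OpenAI-style `{"data": [...]}` (or `{"models": [...]}` via the serde
    /// alias), a bare JSON array of models, and a Gemini-style
    /// `{"models": [{"name": "models/..."}]}` list. Models that advertise
    /// multiple providers are expanded into one `model_id:provider` entry each.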
    pub async fn list_models(&self) -> Result<Vec<Model>> {
        let url = format!("{}{}", self.base_url, self.models_path);

        crate::debug_log!("Requesting models from URL: {}", url);

        let mut req = self
            .client
            .get(&url)
            .header("Content-Type", "application/json");

        req = self.add_standard_headers(req);

        let response = req.send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            crate::debug_log!("API request failed with status {}: {}", status, text);
            anyhow::bail!("API request failed with status {}: {}", status, text);
        }

        let response_text = response.text().await?;

        crate::debug_log!(
            "Received models response ({} bytes): {}",
            response_text.len(),
            response_text
        );

        let models = if let Ok(models_response) =
            serde_json::from_str::<ModelsResponse>(&response_text)
        {
            models_response.data
        } else if let Ok(parsed_models) = serde_json::from_str::<Vec<Model>>(&response_text) {
            parsed_models
        } else {
            // Gemini-style response: {"models": [{"name": "models/<id>", ...}]}
            if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response_text) {
                if let Some(models_array) = json_value.get("models").and_then(|v| v.as_array()) {
                    let mut converted_models = Vec::new();
                    for model_json in models_array {
                        if let Some(name) = model_json.get("name").and_then(|v| v.as_str()) {
                            let id = if name.starts_with("models/") {
                                &name[7..]
                            } else {
                                name
                            };

                            converted_models.push(Model {
                                id: id.to_string(),
                                object: "model".to_string(),
                                providers: vec![],
                            });
                        }
                    }

                    if !converted_models.is_empty() {
                        crate::debug_log!(
                            "Successfully parsed {} Gemini models",
                            converted_models.len()
                        );
                        converted_models
                    } else {
                        anyhow::bail!(
                            "Failed to parse models response. Response: {}",
                            response_text
                        );
                    }
                } else {
                    anyhow::bail!(
                        "Failed to parse models response. Response: {}",
                        response_text
                    );
                }
            } else {
                anyhow::bail!(
                    "Failed to parse models response. Response: {}",
                    response_text
                );
            }
        };

        // Expand multi-provider models into one entry per provider.
        let mut expanded_models = Vec::new();

        for model in models {
            if model.providers.is_empty() {
                expanded_models.push(model);
            } else {
                for provider in &model.providers {
                    let expanded_model = Model {
                        id: format!("{}:{}", model.id, provider.provider),
                        object: model.object.clone(),
                        providers: vec![],
                    };
                    expanded_models.push(expanded_model);
                }
            }
        }

        Ok(expanded_models)
    }

    pub async fn chat_with_tools(&self, request: &ChatRequest) -> Result<ChatResponse> {
        let url = self.get_chat_url(&request.model);

        let mut req = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        if request.stream == Some(true) {
            req = req.header("Accept-Encoding", "identity");
        }

        if !self.custom_headers.contains_key("Authorization")
            && !self.custom_headers.contains_key("authorization")
        {
            req = req.header("Authorization", format!("Bearer {}", self.api_key));
        }

        for (name, value) in &self.custom_headers {
            req = req.header(name, value);
        }

        let should_exclude_model = if let Some(ref config) = self.provider_config {
            config.chat_path.contains("{model}")
        } else {
            self.chat_path.contains("{model}")
        };

        let response = if should_exclude_model {
            let request_without_model = ChatRequestWithoutModel::from(request);
            req.json(&request_without_model).send().await?
        } else {
            req.json(request).send().await?
        };

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            anyhow::bail!("API request failed with status {}: {}", status, text);
        }

        let response_text = response.text().await?;

        if let Ok(chat_response) = serde_json::from_str::<ChatResponse>(&response_text) {
            return Ok(chat_response);
        }

        anyhow::bail!("Failed to parse chat response. Response: {}", response_text);
    }

    pub async fn get_token_from_url(&self, token_url: &str) -> Result<TokenResponse> {
        let mut req = self
            .client
            .get(token_url)
            .header("Authorization", format!("token {}", self.api_key))
            .header("Content-Type", "application/json");

        for (name, value) in &self.custom_headers {
            req = req.header(name, value);
        }

        let response = req.send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            anyhow::bail!("Token request failed with status {}: {}", status, text);
        }

        let token_response: TokenResponse = response.json().await?;
        Ok(token_response)
    }

    pub async fn embeddings(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
        let url = self.build_url("embeddings", &request.model, "/embeddings");

        let mut req = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        req = self.add_standard_headers(req);

        let request_body = if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let template = config.get_endpoint_template("embeddings", &request.model);

                if let Some(template_str) = template {
                    let mut processor_clone = processor.clone();
                    match processor_clone.process_embeddings_request(
                        request,
                        &template_str,
                        &config.vars,
                    ) {
                        Ok(json_value) => Some(json_value),
                        Err(e) => {
                            eprintln!("Warning: Failed to process embeddings request template: {}. Falling back to default.", e);
                            None
                        }
                    }
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        let response = if let Some(json_body) = request_body {
            req.json(&json_body).send().await?
        } else {
            req.json(request).send().await?
        };

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "Embeddings API request failed with status {}: {}",
                status,
                text
            );
        }

        let response_text = response.text().await?;

        if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let template = config.get_endpoint_response_template("embeddings", &request.model);

                if let Some(template_str) = template {
                    if let Ok(response_json) =
                        serde_json::from_str::<serde_json::Value>(&response_text)
                    {
                        let mut processor_clone = processor.clone();
                        match processor_clone.process_response(&response_json, &template_str) {
                            Ok(transformed) => {
                                if let Ok(embedding_response) =
                                    serde_json::from_value::<EmbeddingResponse>(transformed)
                                {
                                    return Ok(embedding_response);
                                }
                            }
                            Err(e) => {
                                eprintln!("Warning: Failed to process embeddings response template: {}. Falling back to default parsing.", e);
                            }
                        }
                    }
                }
            }
        }

        let embedding_response: EmbeddingResponse = serde_json::from_str(&response_text)?;
        Ok(embedding_response)
    }

    pub async fn generate_images(
        &self,
        request: &ImageGenerationRequest,
    ) -> Result<ImageGenerationResponse> {
        let model_name = request.model.as_deref().unwrap_or("");
        let url = self.build_url("images", model_name, "/images/generations");

        let mut req = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        req = self.add_standard_headers(req);

        let request_body = if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let model_name = request.model.as_deref().unwrap_or("");
                let template = config.get_endpoint_template("images", model_name);

                if let Some(template_str) = template {
                    let mut processor_clone = processor.clone();
                    match processor_clone.process_image_request(
                        request,
                        &template_str,
                        &config.vars,
                    ) {
                        Ok(json_value) => Some(json_value),
                        Err(e) => {
                            eprintln!("Warning: Failed to process image request template: {}. Falling back to default.", e);
                            None
                        }
                    }
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        let response = if let Some(json_body) = request_body {
            req.json(&json_body).send().await?
        } else {
            req.json(request).send().await?
        };

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "Image generation API request failed with status {}: {}",
                status,
                text
            );
        }

        let response_text = response.text().await?;

        if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let model_name = request.model.as_deref().unwrap_or("");
                let template = config.get_endpoint_response_template("images", model_name);

                if let Some(template_str) = template {
                    if let Ok(response_json) =
                        serde_json::from_str::<serde_json::Value>(&response_text)
                    {
                        let mut processor_clone = processor.clone();
                        match processor_clone.process_response(&response_json, &template_str) {
                            Ok(transformed) => {
                                if let Ok(image_response) =
                                    serde_json::from_value::<ImageGenerationResponse>(transformed)
                                {
                                    return Ok(image_response);
                                }
                            }
                            Err(e) => {
                                eprintln!("Warning: Failed to process image response template: {}. Falling back to default parsing.", e);
                            }
                        }
                    }
                }
            }
        }

        let image_response: ImageGenerationResponse = serde_json::from_str(&response_text)?;
        Ok(image_response)
    }

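    /// Transcribes audio. The input is accepted either as a data URL
    /// ("data:audio/...;base64,<payload>") or as a bare base64 string; the bytes
    /// are decoded and uploaded as a multipart form, with the file extension
    /// inferred from the MIME type (defaulting to "wav").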
    pub async fn transcribe_audio(
        &self,
        request: &AudioTranscriptionRequest,
    ) -> Result<AudioTranscriptionResponse> {
        use reqwest::multipart;

        let url = self.build_url(
            "audio_transcriptions",
            &request.model,
            "/audio/transcriptions",
        );

        use base64::Engine;
        let audio_bytes = if request.file.starts_with("data:") {
            let parts: Vec<&str> = request.file.splitn(2, ',').collect();
            if parts.len() == 2 {
                base64::engine::general_purpose::STANDARD.decode(parts[1])?
            } else {
                anyhow::bail!("Invalid data URL format");
            }
        } else {
            base64::engine::general_purpose::STANDARD.decode(&request.file)?
        };

        let file_extension = if request.file.starts_with("data:audio/") {
            let mime_part = request.file.split(';').next().unwrap_or("");
            match mime_part {
                "data:audio/mpeg" | "data:audio/mp3" => "mp3",
                "data:audio/wav" | "data:audio/wave" => "wav",
                "data:audio/flac" => "flac",
                "data:audio/ogg" => "ogg",
                "data:audio/webm" => "webm",
                "data:audio/mp4" => "mp4",
                _ => "wav",
            }
        } else {
            "wav"
        };

        let mut form = multipart::Form::new()
            .text("model", request.model.clone())
            .part(
                "file",
                multipart::Part::bytes(audio_bytes)
                    .file_name(format!("audio.{}", file_extension))
                    .mime_str(&format!(
                        "audio/{}",
                        if file_extension == "mp3" {
                            "mpeg"
                        } else {
                            file_extension
                        }
                    ))?,
            );

        if let Some(language) = &request.language {
            form = form.text("language", language.clone());
        }
        if let Some(prompt) = &request.prompt {
            form = form.text("prompt", prompt.clone());
        }
        if let Some(response_format) = &request.response_format {
            form = form.text("response_format", response_format.clone());
        }
        if let Some(temperature) = request.temperature {
            form = form.text("temperature", temperature.to_string());
        }

        let mut req = self.client.post(&url);

        req = self.add_standard_headers(req);

        let response = req.multipart(form).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "Audio transcription API request failed with status {}: {}",
                status,
                text
            );
        }

        let response_text = response.text().await?;

        if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let template = config.get_endpoint_response_template("audio", &request.model);

                if let Some(template_str) = template {
                    if let Ok(response_json) =
                        serde_json::from_str::<serde_json::Value>(&response_text)
                    {
                        let mut processor_clone = processor.clone();
                        match processor_clone.process_response(&response_json, &template_str) {
                            Ok(transformed) => {
                                if let Ok(audio_response) =
                                    serde_json::from_value::<AudioTranscriptionResponse>(
                                        transformed,
                                    )
                                {
                                    return Ok(audio_response);
                                }
                            }
                            Err(e) => {
                                eprintln!("Warning: Failed to process audio response template: {}. Falling back to default parsing.", e);
                            }
                        }
                    }
                }
            }
        }

        // Some providers return plain text instead of JSON for transcriptions.
        if response_text.starts_with('{') {
            let audio_response: AudioTranscriptionResponse = serde_json::from_str(&response_text)?;
            Ok(audio_response)
        } else {
            Ok(AudioTranscriptionResponse {
                text: response_text.trim().to_string(),
                language: None,
                duration: None,
                segments: None,
            })
        }
    }

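    /// Generates speech and returns raw audio bytes. With a response template
    /// the extracted value is treated as base64 and decoded; otherwise, if the
    /// body looks like base64 it is decoded, and failing that the bytes of the
    /// response text are returned as-is.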
    pub async fn generate_speech(&self, request: &AudioSpeechRequest) -> Result<Vec<u8>> {
        let url = self.build_url("audio_speech", &request.model, "/audio/speech");

        let mut req = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        req = self.add_standard_headers(req);

        let request_body = if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let template = config.get_endpoint_template("speech", &request.model);

                if let Some(template_str) = template {
                    let mut processor_clone = processor.clone();
                    match processor_clone.process_speech_request(
                        request,
                        &template_str,
                        &config.vars,
                    ) {
                        Ok(json_value) => Some(json_value),
                        Err(e) => {
                            eprintln!("Warning: Failed to process speech request template: {}. Falling back to default.", e);
                            None
                        }
                    }
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        let response = if let Some(json_body) = request_body {
            req.json(&json_body).send().await?
        } else {
            req.json(request).send().await?
        };

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "Speech generation API request failed with status {}: {}",
                status,
                text
            );
        }

        let response_text = response.text().await?;

        if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let template = config.get_endpoint_response_template("speech", &request.model);

                if let Some(template_str) = template {
                    if let Ok(response_json) =
                        serde_json::from_str::<serde_json::Value>(&response_text)
                    {
                        let mut processor_clone = processor.clone();
                        match processor_clone.process_response(&response_json, &template_str) {
                            Ok(extracted) => {
                                if let Some(base64_data) = extracted.as_str() {
                                    use base64::Engine;
                                    match base64::engine::general_purpose::STANDARD
                                        .decode(base64_data)
                                    {
                                        Ok(audio_bytes) => return Ok(audio_bytes),
                                        Err(e) => {
                                            eprintln!("Warning: Failed to decode base64 audio data: {}. Falling back to default parsing.", e);
                                        }
                                    }
                                }
                            }
                            Err(e) => {
                                eprintln!("Warning: Failed to process speech response template: {}. Falling back to default parsing.", e);
                            }
                        }
                    }
                }
            }
        }

        if response_text
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '/' || c == '=')
        {
            use base64::Engine;
            if let Ok(audio_bytes) =
                base64::engine::general_purpose::STANDARD.decode(&response_text)
            {
                return Ok(audio_bytes);
            }
        }

        Ok(response_text.into_bytes())
    }

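    /// Streams a chat completion to stdout. The body is read as a byte stream
    /// and split on newlines; `data: ` lines are handled as SSE events (with
    /// `[DONE]` terminating the stream), and both the OpenAI delta format
    /// (`choices[0].delta.content`) and a provider-specific top-level
    /// `response` field are printed as they arrive.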
    pub async fn chat_stream(&self, request: &ChatRequest) -> Result<()> {
        use std::io::{stdout, Write};

        let url = self.get_chat_url(&request.model);

        let mut req = self
            .streaming_client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Accept", "text/event-stream")
            .header("Cache-Control", "no-cache")
            .header("Accept-Encoding", "identity");

        let stdout = stdout();
        let mut handle = std::io::BufWriter::new(stdout.lock());

        req = self.add_standard_headers(req);

        let request_body = if let Some(ref config) = &self.provider_config {
            if let Some(ref processor) = &self.template_processor {
                let template = config.get_endpoint_template("chat", &request.model);

                if let Some(template_str) = template {
                    let mut processor_clone = processor.clone();
                    match processor_clone.process_request(request, &template_str, &config.vars) {
                        Ok(json_value) => Some(json_value),
                        Err(e) => {
                            eprintln!("Warning: Failed to process request template: {}. Falling back to default.", e);
                            None
                        }
                    }
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        let should_exclude_model = if let Some(ref config) = self.provider_config {
            config.chat_path.contains("{model}")
        } else {
            self.chat_path.contains("{model}")
        };

        let response = if let Some(json_body) = request_body {
            req.json(&json_body).send().await?
        } else if should_exclude_model {
            let request_without_model = ChatRequestWithoutModel::from(request);
            req.json(&request_without_model).send().await?
        } else {
            req.json(request).send().await?
        };

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            anyhow::bail!("API request failed with status {}: {}", status, text);
        }

        // Content-encoding is irrelevant here: "identity" was requested above,
        // so the body is consumed as-is below.
        let mut stream = response.bytes_stream();

        let mut buffer = String::new();

        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;

            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete lines as they accumulate in the buffer.
            while let Some(newline_pos) = buffer.find('\n') {
                let line = buffer[..newline_pos].to_string();
                buffer.drain(..=newline_pos);

                if line.starts_with("data: ") {
                    let data = &line[6..];
                    if data.trim() == "[DONE]" {
                        handle.write_all(b"\n")?;
                        handle.flush()?;
                        return Ok(());
                    }

                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
                        if let Some(response) = json.get("response") {
                            if let Some(text) = response.as_str() {
                                if !text.is_empty() {
                                    handle.write_all(text.as_bytes())?;
                                    handle.flush()?;
                                }
                            }
                        } else if let Some(choices) = json.get("choices") {
                            if let Some(choice) = choices.get(0) {
                                if let Some(delta) = choice.get("delta") {
                                    if let Some(content) = delta.get("content") {
                                        if let Some(text) = content.as_str() {
                                            handle.write_all(text.as_bytes())?;
                                            handle.flush()?;
                                        }
                                    }
                                }
                            }
                        }
                    }
                } else if line.trim().is_empty() {
                    continue;
                } else {
                    // Some providers stream bare JSON lines without the SSE prefix.
                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
                        if let Some(response) = json.get("response") {
                            if let Some(text) = response.as_str() {
                                if !text.is_empty() {
                                    handle.write_all(text.as_bytes())?;
                                    handle.flush()?;
                                }
                            }
                        } else if let Some(choices) = json.get("choices") {
                            if let Some(choice) = choices.get(0) {
                                if let Some(delta) = choice.get("delta") {
                                    if let Some(content) = delta.get("content") {
                                        if let Some(text) = content.as_str() {
                                            handle.write_all(text.as_bytes())?;
                                            handle.flush()?;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Flush any trailing data left in the buffer after the stream ends.
        if !buffer.trim().is_empty() {
            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&buffer) {
                if let Some(response) = json.get("response") {
                    if let Some(text) = response.as_str() {
                        if !text.is_empty() {
                            handle.write_all(text.as_bytes())?;
                            handle.flush()?;
                        }
                    }
                } else if let Some(choices) = json.get("choices") {
                    if let Some(choice) = choices.get(0) {
                        if let Some(delta) = choice.get("delta") {
                            if let Some(content) = delta.get("content") {
                                if let Some(text) = content.as_str() {
                                    handle.write_all(text.as_bytes())?;
                                    handle.flush()?;
                                }
                            }
                        }
                    }
                }
            }
        }

        handle.write_all(b"\n")?;
        handle.flush()?;
        Ok(())
    }
}
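
// A minimal sketch of unit tests for the request/response types above, added
// as an illustration; the exact cases (and the use of serde_json's `json!`
// macro) are assumptions, not part of the original crate's test suite.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn user_message_serializes_as_flat_content() {
        // `content_type` is flattened, so the text lands directly under "content".
        let msg = Message::user("Hello".to_string());
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value, json!({"role": "user", "content": "Hello"}));
    }

    #[test]
    fn tool_result_carries_tool_call_id() {
        let msg = Message::tool_result("call_1".to_string(), "42".to_string());
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            json!({"role": "tool", "content": "42", "tool_call_id": "call_1"})
        );
    }

    #[test]
    fn request_without_model_copies_all_other_fields() {
        let request = ChatRequest {
            model: "test-model".to_string(),
            messages: vec![Message::user("Hi".to_string())],
            max_tokens: Some(16),
            temperature: Some(0.5),
            tools: None,
            stream: Some(false),
        };
        let stripped = ChatRequestWithoutModel::from(&request);
        assert_eq!(stripped.messages.len(), 1);
        assert_eq!(stripped.max_tokens, Some(16));
        assert_eq!(stripped.temperature, Some(0.5));
        assert_eq!(stripped.stream, Some(false));
    }
}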