1use super::provider::{ChatStream, LLMProvider, ModelCapabilities, ModelInfo};
31use super::types::*;
32use async_openai::{
33 Client,
34 config::OpenAIConfig as AsyncOpenAIConfig,
35 types::{
36 ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
37 ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessageContentPartImage,
38 ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessageArgs,
39 ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
40 ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
41 ChatCompletionToolArgs, ChatCompletionToolChoiceOption, ChatCompletionToolType,
42 CreateChatCompletionRequestArgs, FunctionObjectArgs, ImageDetail as OpenAIImageDetail,
43 ImageUrl as OpenAIImageUrl, InputAudio, InputAudioFormat,
44 },
45};
46use async_trait::async_trait;
47use futures::StreamExt;
48
/// Connection and request-default settings for an OpenAI (or OpenAI-compatible) provider.
///
/// Build one with [`OpenAIConfig::new`] or [`OpenAIConfig::from_env`] and refine it
/// with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    /// API key sent to the server. Local servers may accept any placeholder.
    pub api_key: String,
    /// Override for the API base URL (proxy, Azure deployment, local server, ...).
    /// `None` keeps the client library's default endpoint.
    pub base_url: Option<String>,
    /// Optional OpenAI organization id.
    pub org_id: Option<String>,
    /// Model used when a request does not name one.
    pub default_model: String,
    /// Sampling temperature used when a chat request leaves `temperature` unset.
    pub default_temperature: f32,
    /// Default completion-token budget.
    ///
    /// NOTE(review): not read by the request builders in this file — confirm whether
    /// it is meant to apply when a request leaves `max_tokens` unset.
    pub default_max_tokens: u32,
    /// Request timeout in seconds.
    ///
    /// NOTE(review): not passed to the HTTP client by the provider constructors in
    /// this file — verify before relying on it.
    pub timeout_secs: u64,
}

impl Default for OpenAIConfig {
    /// Empty API key, library-default endpoint, `gpt-4o`, temperature 0.7,
    /// 4096 max tokens and a 60 second timeout.
    fn default() -> Self {
        Self {
            api_key: String::new(),
            base_url: None,
            org_id: None,
            default_model: "gpt-4o".to_string(),
            default_temperature: 0.7,
            default_max_tokens: 4096,
            timeout_secs: 60,
        }
    }
}

impl OpenAIConfig {
    /// Creates a configuration with the given API key and all other fields defaulted.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            ..Default::default()
        }
    }

    /// Creates a configuration from environment variables.
    ///
    /// Reads `OPENAI_API_KEY`, `OPENAI_BASE_URL` and `OPENAI_MODEL`. A variable that is
    /// unset, empty or not valid Unicode leaves the corresponding default in place.
    /// In particular, the default model (`gpt-4o`) is kept when `OPENAI_MODEL` is unset,
    /// instead of being replaced by an empty model name.
    pub fn from_env() -> Self {
        Self::from_lookup(|key| std::env::var(key).ok())
    }

    /// Builds a configuration from an arbitrary key lookup.
    ///
    /// This is the logic behind [`OpenAIConfig::from_env`], factored out so it can be
    /// exercised without mutating the process environment. Empty values are treated
    /// as missing.
    fn from_lookup<F>(lookup: F) -> Self
    where
        F: Fn(&str) -> Option<String>,
    {
        let get = |key: &str| lookup(key).filter(|value| !value.is_empty());

        let mut config = Self::default();
        if let Some(api_key) = get("OPENAI_API_KEY") {
            config.api_key = api_key;
        }
        config.base_url = get("OPENAI_BASE_URL");
        if let Some(model) = get("OPENAI_MODEL") {
            config.default_model = model;
        }
        config
    }

    /// Sets the API base URL.
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = Some(url.into());
        self
    }

    /// Sets the model used when a request does not name one.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.default_model = model.into();
        self
    }

    /// Sets the fallback sampling temperature.
    #[must_use]
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.default_temperature = temp;
        self
    }

    /// Sets the default completion-token budget.
    #[must_use]
    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
        self.default_max_tokens = tokens;
        self
    }

    /// Sets the OpenAI organization id.
    #[must_use]
    pub fn with_org_id(mut self, org_id: impl Into<String>) -> Self {
        self.org_id = Some(org_id.into());
        self
    }

    /// Sets the request timeout in seconds.
    #[must_use]
    pub fn with_timeout(mut self, secs: u64) -> Self {
        self.timeout_secs = secs;
        self
    }
}
137
138pub struct OpenAIProvider {
142 client: Client<AsyncOpenAIConfig>,
143 config: OpenAIConfig,
144}
145
146impl OpenAIProvider {
147 pub fn new(api_key: impl Into<String>) -> Self {
149 let config = OpenAIConfig::new(api_key);
150 Self::with_config(config)
151 }
152
153 pub fn from_env() -> Self {
155 Self::with_config(OpenAIConfig::from_env())
156 }
157
158 pub fn with_config(config: OpenAIConfig) -> Self {
160 let mut openai_config = AsyncOpenAIConfig::new().with_api_key(&config.api_key);
161
162 if let Some(ref base_url) = config.base_url {
163 openai_config = openai_config.with_api_base(base_url);
164 }
165
166 if let Some(ref org_id) = config.org_id {
167 openai_config = openai_config.with_org_id(org_id);
168 }
169
170 let client = Client::with_config(openai_config);
171
172 Self { client, config }
173 }
174
175 pub fn azure(
177 endpoint: impl Into<String>,
178 api_key: impl Into<String>,
179 deployment: impl Into<String>,
180 ) -> Self {
181 let endpoint = endpoint.into();
182 let deployment = deployment.into();
183
184 let base_url = format!(
186 "{}/openai/deployments/{}",
187 endpoint.trim_end_matches('/'),
188 deployment
189 );
190
191 let config = OpenAIConfig::new(api_key)
192 .with_base_url(base_url)
193 .with_model(deployment);
194
195 Self::with_config(config)
196 }
197
198 pub fn local(base_url: impl Into<String>, model: impl Into<String>) -> Self {
200 let config = OpenAIConfig::new("not-needed")
201 .with_base_url(base_url)
202 .with_model(model);
203
204 Self::with_config(config)
205 }
206
207 pub fn ollama(model: impl Into<String>) -> Self {
209 Self::local("http://localhost:11434/v1", model)
210 }
211
212 pub fn client(&self) -> &Client<AsyncOpenAIConfig> {
214 &self.client
215 }
216
217 pub fn config(&self) -> &OpenAIConfig {
219 &self.config
220 }
221
222 fn convert_messages(
224 messages: &[ChatMessage],
225 ) -> Result<Vec<ChatCompletionRequestMessage>, LLMError> {
226 messages.iter().map(Self::convert_message).collect()
227 }
228
229 fn convert_message(msg: &ChatMessage) -> Result<ChatCompletionRequestMessage, LLMError> {
231 let text_only_content = msg
232 .content
233 .as_ref()
234 .map(|c| match c {
235 MessageContent::Text(s) => s.clone(),
236 MessageContent::Parts(parts) => parts
237 .iter()
238 .filter_map(|p| match p {
239 ContentPart::Text { text } => Some(text.clone()),
240 _ => None,
241 })
242 .collect::<Vec<_>>()
243 .join("\n"),
244 })
245 .unwrap_or_default();
246
247 match msg.role {
248 Role::System => Ok(ChatCompletionRequestSystemMessageArgs::default()
249 .content(text_only_content)
250 .build()
251 .map_err(|e| LLMError::Other(e.to_string()))?
252 .into()),
253 Role::User => {
254 let content = match msg.content.as_ref() {
255 Some(MessageContent::Text(s)) => {
256 ChatCompletionRequestUserMessageContent::Text(s.clone())
257 }
258 Some(MessageContent::Parts(parts)) => {
259 let mut out = Vec::new();
260 for part in parts {
261 match part {
262 ContentPart::Text { text } => {
263 out.push(ChatCompletionRequestUserMessageContentPart::Text(
264 ChatCompletionRequestMessageContentPartText {
265 text: text.clone(),
266 },
267 ));
268 }
269 ContentPart::Image { image_url } => {
270 let detail = image_url.detail.as_ref().map(|d| match d {
271 ImageDetail::Auto => OpenAIImageDetail::Auto,
272 ImageDetail::Low => OpenAIImageDetail::Low,
273 ImageDetail::High => OpenAIImageDetail::High,
274 });
275 let image_part = ChatCompletionRequestMessageContentPartImage {
276 image_url: OpenAIImageUrl {
277 url: image_url.url.clone(),
278 detail,
279 },
280 };
281 out.push(
282 ChatCompletionRequestUserMessageContentPart::ImageUrl(
283 image_part,
284 ),
285 );
286 }
287 ContentPart::Audio { audio } => {
288 let format = match audio.format.to_lowercase().as_str() {
289 "wav" => InputAudioFormat::Wav,
290 _ => InputAudioFormat::Mp3,
291 };
292 let audio_part = ChatCompletionRequestMessageContentPartAudio {
293 input_audio: InputAudio {
294 data: audio.data.clone(),
295 format,
296 },
297 };
298 out.push(
299 ChatCompletionRequestUserMessageContentPart::InputAudio(
300 audio_part,
301 ),
302 );
303 }
304 }
305 }
306 ChatCompletionRequestUserMessageContent::Array(out)
307 }
308 None => ChatCompletionRequestUserMessageContent::Text(String::new()),
309 };
310
311 Ok(ChatCompletionRequestUserMessageArgs::default()
312 .content(content)
313 .build()
314 .map_err(|e| LLMError::Other(e.to_string()))?
315 .into())
316 }
317 Role::Assistant => {
318 let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
319 if !text_only_content.is_empty() {
320 builder.content(text_only_content);
321 }
322
323 if let Some(ref tool_calls) = msg.tool_calls {
325 let converted_calls: Vec<_> = tool_calls
326 .iter()
327 .map(|tc| async_openai::types::ChatCompletionMessageToolCall {
328 id: tc.id.clone(),
329 r#type: ChatCompletionToolType::Function,
330 function: async_openai::types::FunctionCall {
331 name: tc.function.name.clone(),
332 arguments: tc.function.arguments.clone(),
333 },
334 })
335 .collect();
336 builder.tool_calls(converted_calls);
337 }
338
339 Ok(builder
340 .build()
341 .map_err(|e| LLMError::Other(e.to_string()))?
342 .into())
343 }
344 Role::Tool => {
345 let tool_call_id = msg
346 .tool_call_id
347 .clone()
348 .unwrap_or_else(|| "unknown".to_string());
349
350 Ok(ChatCompletionRequestToolMessageArgs::default()
351 .tool_call_id(tool_call_id)
352 .content(text_only_content)
353 .build()
354 .map_err(|e| LLMError::Other(e.to_string()))?
355 .into())
356 }
357 }
358 }
359
360 fn convert_tools(
362 tools: &[Tool],
363 ) -> Result<Vec<async_openai::types::ChatCompletionTool>, LLMError> {
364 tools
365 .iter()
366 .map(|tool| {
367 let function = FunctionObjectArgs::default()
368 .name(&tool.function.name)
369 .description(tool.function.description.clone().unwrap_or_default())
370 .parameters(
371 tool.function
372 .parameters
373 .clone()
374 .unwrap_or(serde_json::json!({})),
375 )
376 .build()
377 .map_err(|e| LLMError::Other(e.to_string()))?;
378
379 ChatCompletionToolArgs::default()
380 .r#type(ChatCompletionToolType::Function)
381 .function(function)
382 .build()
383 .map_err(|e| LLMError::Other(e.to_string()))
384 })
385 .collect()
386 }
387
388 fn convert_response(
390 response: async_openai::types::CreateChatCompletionResponse,
391 ) -> ChatCompletionResponse {
392 let choices: Vec<Choice> = response
393 .choices
394 .into_iter()
395 .map(|choice| {
396 let message = Self::convert_response_message(choice.message);
397 let finish_reason = choice.finish_reason.map(|r| match r {
398 async_openai::types::FinishReason::Stop => FinishReason::Stop,
399 async_openai::types::FinishReason::Length => FinishReason::Length,
400 async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
401 async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
402 async_openai::types::FinishReason::FunctionCall => FinishReason::ToolCalls,
403 });
404
405 Choice {
406 index: choice.index,
407 message,
408 finish_reason,
409 logprobs: None,
410 }
411 })
412 .collect();
413
414 let usage = response.usage.map(|u| Usage {
415 prompt_tokens: u.prompt_tokens,
416 completion_tokens: u.completion_tokens,
417 total_tokens: u.total_tokens,
418 });
419
420 ChatCompletionResponse {
421 id: response.id,
422 object: response.object,
423 created: response.created as u64,
424 model: response.model,
425 choices,
426 usage,
427 system_fingerprint: response.system_fingerprint,
428 }
429 }
430
431 fn convert_response_message(
433 msg: async_openai::types::ChatCompletionResponseMessage,
434 ) -> ChatMessage {
435 let content = msg.content.map(MessageContent::Text);
436
437 let tool_calls = msg.tool_calls.map(|calls| {
438 calls
439 .into_iter()
440 .map(|tc| ToolCall {
441 id: tc.id,
442 call_type: "function".to_string(),
443 function: FunctionCall {
444 name: tc.function.name,
445 arguments: tc.function.arguments,
446 },
447 })
448 .collect()
449 });
450
451 ChatMessage {
452 role: Role::Assistant,
453 content,
454 name: None,
455 tool_calls,
456 tool_call_id: None,
457 }
458 }
459}
460
461#[async_trait]
462impl LLMProvider for OpenAIProvider {
463 fn name(&self) -> &str {
464 "openai"
465 }
466
467 fn default_model(&self) -> &str {
468 &self.config.default_model
469 }
470
471 fn supported_models(&self) -> Vec<&str> {
472 vec![
473 "gpt-4o",
474 "gpt-4o-mini",
475 "gpt-4-turbo",
476 "gpt-4",
477 "gpt-3.5-turbo",
478 "o1",
479 "o1-mini",
480 "o1-preview",
481 ]
482 }
483
484 fn supports_streaming(&self) -> bool {
485 true
486 }
487
488 fn supports_tools(&self) -> bool {
489 true
490 }
491
492 fn supports_vision(&self) -> bool {
493 true
494 }
495
496 fn supports_embedding(&self) -> bool {
497 true
498 }
499
500 async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
501 let messages = Self::convert_messages(&request.messages)?;
502
503 let model = if request.model.is_empty() {
504 self.config.default_model.clone()
505 } else {
506 request.model.clone()
507 };
508
509 let mut builder = CreateChatCompletionRequestArgs::default();
510 builder.model(&model).messages(messages);
511
512 if let Some(temp) = request.temperature {
514 builder.temperature(temp);
515 } else {
516 builder.temperature(self.config.default_temperature);
517 }
518
519 if let Some(max_tokens) = request.max_tokens {
520 builder.max_tokens(max_tokens);
521 }
522
523 if let Some(top_p) = request.top_p {
524 builder.top_p(top_p);
525 }
526
527 if let Some(ref stop) = request.stop {
528 builder.stop(stop.clone());
529 }
530
531 if let Some(freq_penalty) = request.frequency_penalty {
532 builder.frequency_penalty(freq_penalty);
533 }
534
535 if let Some(pres_penalty) = request.presence_penalty {
536 builder.presence_penalty(pres_penalty);
537 }
538
539 if let Some(ref user) = request.user {
540 builder.user(user);
541 }
542
543 if let Some(ref tools) = request.tools
545 && !tools.is_empty()
546 {
547 let converted_tools = Self::convert_tools(tools)?;
548 builder.tools(converted_tools);
549
550 if let Some(ref choice) = request.tool_choice {
552 let tc = match choice {
553 ToolChoice::Auto => ChatCompletionToolChoiceOption::Auto,
554 ToolChoice::None => ChatCompletionToolChoiceOption::None,
555 ToolChoice::Required => ChatCompletionToolChoiceOption::Required,
556 ToolChoice::Specific { function, .. } => ChatCompletionToolChoiceOption::Named(
557 async_openai::types::ChatCompletionNamedToolChoice {
558 r#type: ChatCompletionToolType::Function,
559 function: async_openai::types::FunctionName {
560 name: function.name.clone(),
561 },
562 },
563 ),
564 };
565 builder.tool_choice(tc);
566 }
567 }
568
569 if let Some(ref format) = request.response_format
571 && format.format_type == "json_object"
572 {
573 builder.response_format(async_openai::types::ResponseFormat::JsonObject);
574 }
575
576 let openai_request = builder
577 .build()
578 .map_err(|e| LLMError::ConfigError(e.to_string()))?;
579
580 let response = self
581 .client
582 .chat()
583 .create(openai_request)
584 .await
585 .map_err(Self::convert_error)?;
586
587 Ok(Self::convert_response(response))
588 }
589
590 async fn chat_stream(&self, request: ChatCompletionRequest) -> LLMResult<ChatStream> {
591 let messages = Self::convert_messages(&request.messages)?;
592
593 let model = if request.model.is_empty() {
594 self.config.default_model.clone()
595 } else {
596 request.model.clone()
597 };
598
599 let mut builder = CreateChatCompletionRequestArgs::default();
600 builder.model(&model).messages(messages).stream(true);
601
602 if let Some(temp) = request.temperature {
603 builder.temperature(temp);
604 }
605
606 if let Some(max_tokens) = request.max_tokens {
607 builder.max_tokens(max_tokens);
608 }
609
610 if let Some(ref tools) = request.tools
612 && !tools.is_empty()
613 {
614 let converted_tools = Self::convert_tools(tools)?;
615 builder.tools(converted_tools);
616 }
617
618 let openai_request = builder
619 .build()
620 .map_err(|e| LLMError::ConfigError(e.to_string()))?;
621
622 let stream = self
623 .client
624 .chat()
625 .create_stream(openai_request)
626 .await
627 .map_err(Self::convert_error)?;
628
629 let converted_stream = stream
631 .filter_map(|result| async move {
632 match result {
633 Ok(chunk) => Some(Ok(Self::convert_chunk(chunk))),
634 Err(e) => {
635 let err_str = e.to_string();
636 if err_str.contains("stream did not contain valid UTF-8") || err_str.contains("utf8") {
638 tracing::warn!("Skipping invalid UTF-8 chunk from stream (may happen with some OpenAI-compatible APIs)");
639 None
640 } else {
641 Some(Err(Self::convert_error(e)))
642 }
643 }
644 }
645 });
646
647 Ok(Box::pin(converted_stream))
648 }
649
650 async fn embedding(&self, request: EmbeddingRequest) -> LLMResult<EmbeddingResponse> {
651 use async_openai::types::CreateEmbeddingRequestArgs;
652
653 let input = match request.input {
654 EmbeddingInput::Single(s) => vec![s],
655 EmbeddingInput::Multiple(v) => v,
656 };
657
658 let openai_request = CreateEmbeddingRequestArgs::default()
659 .model(&request.model)
660 .input(input)
661 .build()
662 .map_err(|e| LLMError::ConfigError(e.to_string()))?;
663
664 let response = self
665 .client
666 .embeddings()
667 .create(openai_request)
668 .await
669 .map_err(Self::convert_error)?;
670
671 let data: Vec<EmbeddingData> = response
672 .data
673 .into_iter()
674 .map(|d| EmbeddingData {
675 object: "embedding".to_string(),
676 index: d.index,
677 embedding: d.embedding,
678 })
679 .collect();
680
681 Ok(EmbeddingResponse {
682 object: "list".to_string(),
683 model: response.model,
684 data,
685 usage: EmbeddingUsage {
686 prompt_tokens: response.usage.prompt_tokens,
687 total_tokens: response.usage.total_tokens,
688 },
689 })
690 }
691
692 async fn health_check(&self) -> LLMResult<bool> {
693 let request = ChatCompletionRequest::new(&self.config.default_model)
695 .system("Say 'ok'")
696 .max_tokens(5);
697
698 match self.chat(request).await {
699 Ok(_) => Ok(true),
700 Err(_) => Ok(false),
701 }
702 }
703
704 async fn get_model_info(&self, model: &str) -> LLMResult<ModelInfo> {
705 let info = match model {
707 "gpt-4o" => ModelInfo {
708 id: "gpt-4o".to_string(),
709 name: "GPT-4o".to_string(),
710 description: Some("Most capable GPT-4 model with vision".to_string()),
711 context_window: Some(128000),
712 max_output_tokens: Some(16384),
713 training_cutoff: Some("2023-10".to_string()),
714 capabilities: ModelCapabilities {
715 streaming: true,
716 tools: true,
717 vision: true,
718 json_mode: true,
719 json_schema: true,
720 },
721 },
722 "gpt-4o-mini" => ModelInfo {
723 id: "gpt-4o-mini".to_string(),
724 name: "GPT-4o Mini".to_string(),
725 description: Some("Smaller, faster GPT-4o".to_string()),
726 context_window: Some(128000),
727 max_output_tokens: Some(16384),
728 training_cutoff: Some("2023-10".to_string()),
729 capabilities: ModelCapabilities {
730 streaming: true,
731 tools: true,
732 vision: true,
733 json_mode: true,
734 json_schema: true,
735 },
736 },
737 "gpt-4-turbo" => ModelInfo {
738 id: "gpt-4-turbo".to_string(),
739 name: "GPT-4 Turbo".to_string(),
740 description: Some("GPT-4 Turbo with vision".to_string()),
741 context_window: Some(128000),
742 max_output_tokens: Some(4096),
743 training_cutoff: Some("2023-12".to_string()),
744 capabilities: ModelCapabilities {
745 streaming: true,
746 tools: true,
747 vision: true,
748 json_mode: true,
749 json_schema: false,
750 },
751 },
752 "gpt-3.5-turbo" => ModelInfo {
753 id: "gpt-3.5-turbo".to_string(),
754 name: "GPT-3.5 Turbo".to_string(),
755 description: Some("Fast and cost-effective".to_string()),
756 context_window: Some(16385),
757 max_output_tokens: Some(4096),
758 training_cutoff: Some("2021-09".to_string()),
759 capabilities: ModelCapabilities {
760 streaming: true,
761 tools: true,
762 vision: false,
763 json_mode: true,
764 json_schema: false,
765 },
766 },
767 _ => ModelInfo {
768 id: model.to_string(),
769 name: model.to_string(),
770 description: None,
771 context_window: None,
772 max_output_tokens: None,
773 training_cutoff: None,
774 capabilities: ModelCapabilities::default(),
775 },
776 };
777
778 Ok(info)
779 }
780}
781
782impl OpenAIProvider {
783 fn convert_chunk(
785 chunk: async_openai::types::CreateChatCompletionStreamResponse,
786 ) -> ChatCompletionChunk {
787 let choices: Vec<ChunkChoice> = chunk
788 .choices
789 .into_iter()
790 .map(|choice| {
791 let delta = ChunkDelta {
792 role: choice.delta.role.map(|_| Role::Assistant),
793 content: choice.delta.content,
794 tool_calls: choice.delta.tool_calls.map(|calls| {
795 calls
796 .into_iter()
797 .map(|tc| ToolCallDelta {
798 index: tc.index,
799 id: tc.id,
800 call_type: Some("function".to_string()),
801 function: tc.function.map(|f| FunctionCallDelta {
802 name: f.name,
803 arguments: f.arguments,
804 }),
805 })
806 .collect()
807 }),
808 };
809
810 let finish_reason = choice.finish_reason.map(|r| match r {
811 async_openai::types::FinishReason::Stop => FinishReason::Stop,
812 async_openai::types::FinishReason::Length => FinishReason::Length,
813 async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
814 async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
815 async_openai::types::FinishReason::FunctionCall => FinishReason::ToolCalls,
816 });
817
818 ChunkChoice {
819 index: choice.index,
820 delta,
821 finish_reason,
822 }
823 })
824 .collect();
825
826 ChatCompletionChunk {
827 id: chunk.id,
828 object: "chat.completion.chunk".to_string(),
829 created: chunk.created as u64,
830 model: chunk.model,
831 choices,
832 usage: chunk.usage.map(|u| Usage {
833 prompt_tokens: u.prompt_tokens,
834 completion_tokens: u.completion_tokens,
835 total_tokens: u.total_tokens,
836 }),
837 }
838 }
839
840 fn convert_error(err: async_openai::error::OpenAIError) -> LLMError {
842 match err {
843 async_openai::error::OpenAIError::ApiError(api_err) => {
844 let code = api_err.code.clone();
845 let message = api_err.message.clone();
846
847 if message.contains("rate limit") {
849 LLMError::RateLimited(message)
850 } else if message.contains("quota") || message.contains("billing") {
851 LLMError::QuotaExceeded(message)
852 } else if message.contains("model") && message.contains("not found") {
853 LLMError::ModelNotFound(message)
854 } else if message.contains("context") || message.contains("tokens") {
855 LLMError::ContextLengthExceeded(message)
856 } else if message.contains("content") && message.contains("filter") {
857 LLMError::ContentFiltered(message)
858 } else {
859 LLMError::ApiError { code, message }
860 }
861 }
862 async_openai::error::OpenAIError::Reqwest(e) => {
863 if e.is_timeout() {
864 LLMError::Timeout(e.to_string())
865 } else {
866 LLMError::NetworkError(e.to_string())
867 }
868 }
869 async_openai::error::OpenAIError::InvalidArgument(msg) => LLMError::ConfigError(msg),
870 _ => LLMError::Other(err.to_string()),
871 }
872 }
873
874 pub fn openai(api_key: impl Into<String>) -> OpenAIProvider {
876 OpenAIProvider::new(api_key)
877 }
878
879 pub fn openai_compatible(
881 base_url: impl Into<String>,
882 api_key: impl Into<String>,
883 model: impl Into<String>,
884 ) -> OpenAIProvider {
885 let config = OpenAIConfig::new(api_key)
886 .with_base_url(base_url)
887 .with_model(model);
888 OpenAIProvider::with_config(config)
889 }
890}
891
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method stores its argument in the matching field.
    #[test]
    fn test_config_builder() {
        let config = OpenAIConfig::new("sk-test")
            .with_base_url("http://localhost:8080")
            .with_model("gpt-4")
            .with_temperature(0.5)
            .with_max_tokens(2048)
            .with_org_id("org-test")
            .with_timeout(15);

        assert_eq!(config.api_key, "sk-test");
        assert_eq!(config.base_url, Some("http://localhost:8080".to_string()));
        assert_eq!(config.default_model, "gpt-4");
        assert_eq!(config.default_temperature, 0.5);
        assert_eq!(config.default_max_tokens, 2048);
        assert_eq!(config.org_id, Some("org-test".to_string()));
        assert_eq!(config.timeout_secs, 15);
    }

    /// `Default` yields the documented fallback values.
    #[test]
    fn test_config_defaults() {
        let config = OpenAIConfig::default();

        assert!(config.api_key.is_empty());
        assert_eq!(config.base_url, None);
        assert_eq!(config.org_id, None);
        assert_eq!(config.default_model, "gpt-4o");
        assert_eq!(config.default_temperature, 0.7);
        assert_eq!(config.default_max_tokens, 4096);
        assert_eq!(config.timeout_secs, 60);
    }

    #[test]
    fn test_provider_name() {
        let provider = OpenAIProvider::new("test-key");
        assert_eq!(provider.name(), "openai");
    }

    /// The Ollama shortcut points at the local server and uses the given model.
    #[test]
    fn test_ollama_provider_config() {
        let provider = OpenAIProvider::ollama("llama3");

        assert_eq!(
            provider.config().base_url.as_deref(),
            Some("http://localhost:11434/v1")
        );
        assert_eq!(provider.default_model(), "llama3");
    }

    /// The Azure constructor strips a trailing slash and uses the deployment as model.
    #[test]
    fn test_azure_provider_config() {
        let provider =
            OpenAIProvider::azure("https://example.openai.azure.com/", "key", "my-deploy");

        assert_eq!(
            provider.config().base_url.as_deref(),
            Some("https://example.openai.azure.com/openai/deployments/my-deploy")
        );
        assert_eq!(provider.config().default_model, "my-deploy");
    }
}