1use super::provider::{ChatStream, LLMProvider, ModelCapabilities, ModelInfo};
31use super::types::*;
32use async_openai::{
33 Client,
34 config::OpenAIConfig as AsyncOpenAIConfig,
35 types::{
36 ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
37 ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessageContentPartImage,
38 ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessageArgs,
39 ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
40 ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
41 ChatCompletionToolArgs, ChatCompletionToolChoiceOption, ChatCompletionToolType,
42 CreateChatCompletionRequestArgs, FunctionObjectArgs, ImageDetail as OpenAIImageDetail,
43 ImageUrl as OpenAIImageUrl, InputAudio, InputAudioFormat,
44 },
45};
46use async_trait::async_trait;
47use futures::StreamExt;
48
/// Connection settings and request defaults for the OpenAI provider.
///
/// `Debug` is implemented by hand (instead of derived) so that the API key is
/// never written to logs or panic messages.
#[derive(Clone)]
pub struct OpenAIConfig {
    /// Secret API key used to authenticate requests.
    pub api_key: String,
    /// Optional override of the API base URL (proxies, OpenAI-compatible servers).
    pub base_url: Option<String>,
    /// Optional OpenAI organization id.
    pub org_id: Option<String>,
    /// Model name used when a request does not specify one.
    pub default_model: String,
    /// Default sampling temperature.
    pub default_temperature: f32,
    /// Default output-token limit (set via `with_max_tokens`).
    pub default_max_tokens: u32,
    /// Intended request timeout in seconds.
    /// NOTE(review): confirm where (if anywhere) this is applied to the HTTP client.
    pub timeout_secs: u64,
}

impl std::fmt::Debug for OpenAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself; only show whether one is configured.
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("OpenAIConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("org_id", &self.org_id)
            .field("default_model", &self.default_model)
            .field("default_temperature", &self.default_temperature)
            .field("default_max_tokens", &self.default_max_tokens)
            .field("timeout_secs", &self.timeout_secs)
            .finish()
    }
}
67
68impl Default for OpenAIConfig {
69 fn default() -> Self {
70 Self {
71 api_key: String::new(),
72 base_url: None,
73 org_id: None,
74 default_model: "gpt-4o".to_string(),
75 default_temperature: 0.7,
76 default_max_tokens: 4096,
77 timeout_secs: 60,
78 }
79 }
80}
81
82impl OpenAIConfig {
83 pub fn new(api_key: impl Into<String>) -> Self {
85 Self {
86 api_key: api_key.into(),
87 ..Default::default()
88 }
89 }
90
91 pub fn from_env() -> Self {
93 Self {
94 api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
95 base_url: std::env::var("OPENAI_BASE_URL").ok(),
96 default_model: std::env::var("OPENAI_MODEL").unwrap_or_default(),
97 ..Default::default()
98 }
99 }
100
101 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
103 self.base_url = Some(url.into());
104 self
105 }
106
107 pub fn with_model(mut self, model: impl Into<String>) -> Self {
109 self.default_model = model.into();
110 self
111 }
112
113 pub fn with_temperature(mut self, temp: f32) -> Self {
115 self.default_temperature = temp;
116 self
117 }
118
119 pub fn with_max_tokens(mut self, tokens: u32) -> Self {
121 self.default_max_tokens = tokens;
122 self
123 }
124
125 pub fn with_org_id(mut self, org_id: impl Into<String>) -> Self {
127 self.org_id = Some(org_id.into());
128 self
129 }
130
131 pub fn with_timeout(mut self, secs: u64) -> Self {
133 self.timeout_secs = secs;
134 self
135 }
136}
137
138pub struct OpenAIProvider {
142 client: Client<AsyncOpenAIConfig>,
143 config: OpenAIConfig,
144}
145
146impl OpenAIProvider {
147 pub fn new(api_key: impl Into<String>) -> Self {
149 let config = OpenAIConfig::new(api_key);
150 Self::with_config(config)
151 }
152
153 pub fn from_env() -> Self {
155 Self::with_config(OpenAIConfig::from_env())
156 }
157
158 pub fn with_config(config: OpenAIConfig) -> Self {
160 let mut openai_config = AsyncOpenAIConfig::new().with_api_key(&config.api_key);
161
162 if let Some(ref base_url) = config.base_url {
163 openai_config = openai_config.with_api_base(base_url);
164 }
165
166 if let Some(ref org_id) = config.org_id {
167 openai_config = openai_config.with_org_id(org_id);
168 }
169
170 let client = Client::with_config(openai_config);
171
172 Self { client, config }
173 }
174
175 pub fn azure(
177 endpoint: impl Into<String>,
178 api_key: impl Into<String>,
179 deployment: impl Into<String>,
180 ) -> Self {
181 let endpoint = endpoint.into();
182 let deployment = deployment.into();
183
184 let base_url = format!(
186 "{}/openai/deployments/{}",
187 endpoint.trim_end_matches('/'),
188 deployment
189 );
190
191 let config = OpenAIConfig::new(api_key)
192 .with_base_url(base_url)
193 .with_model(deployment);
194
195 Self::with_config(config)
196 }
197
198 pub fn local(base_url: impl Into<String>, model: impl Into<String>) -> Self {
200 let config = OpenAIConfig::new("not-needed")
201 .with_base_url(base_url)
202 .with_model(model);
203
204 Self::with_config(config)
205 }
206
207 pub fn client(&self) -> &Client<AsyncOpenAIConfig> {
209 &self.client
210 }
211
212 pub fn config(&self) -> &OpenAIConfig {
214 &self.config
215 }
216
217 fn convert_messages(
219 messages: &[ChatMessage],
220 ) -> Result<Vec<ChatCompletionRequestMessage>, LLMError> {
221 messages.iter().map(Self::convert_message).collect()
222 }
223
224 fn convert_message(msg: &ChatMessage) -> Result<ChatCompletionRequestMessage, LLMError> {
226 let text_only_content = msg
227 .content
228 .as_ref()
229 .map(|c| match c {
230 MessageContent::Text(s) => s.clone(),
231 MessageContent::Parts(parts) => parts
232 .iter()
233 .filter_map(|p| match p {
234 ContentPart::Text { text } => Some(text.clone()),
235 _ => None,
236 })
237 .collect::<Vec<_>>()
238 .join("\n"),
239 })
240 .unwrap_or_default();
241
242 match msg.role {
243 Role::System => Ok(ChatCompletionRequestSystemMessageArgs::default()
244 .content(text_only_content)
245 .build()
246 .map_err(|e| LLMError::Other(e.to_string()))?
247 .into()),
248 Role::User => {
249 let content = match msg.content.as_ref() {
250 Some(MessageContent::Text(s)) => {
251 ChatCompletionRequestUserMessageContent::Text(s.clone())
252 }
253 Some(MessageContent::Parts(parts)) => {
254 let mut out = Vec::new();
255 for part in parts {
256 match part {
257 ContentPart::Text { text } => {
258 out.push(ChatCompletionRequestUserMessageContentPart::Text(
259 ChatCompletionRequestMessageContentPartText {
260 text: text.clone(),
261 },
262 ));
263 }
264 ContentPart::Image { image_url } => {
265 let detail = image_url.detail.as_ref().map(|d| match d {
266 ImageDetail::Auto => OpenAIImageDetail::Auto,
267 ImageDetail::Low => OpenAIImageDetail::Low,
268 ImageDetail::High => OpenAIImageDetail::High,
269 });
270 let image_part = ChatCompletionRequestMessageContentPartImage {
271 image_url: OpenAIImageUrl {
272 url: image_url.url.clone(),
273 detail,
274 },
275 };
276 out.push(
277 ChatCompletionRequestUserMessageContentPart::ImageUrl(
278 image_part,
279 ),
280 );
281 }
282 ContentPart::Audio { audio } => {
283 let format = match audio.format.to_lowercase().as_str() {
284 "wav" => InputAudioFormat::Wav,
285 _ => InputAudioFormat::Mp3,
286 };
287 let audio_part = ChatCompletionRequestMessageContentPartAudio {
288 input_audio: InputAudio {
289 data: audio.data.clone(),
290 format,
291 },
292 };
293 out.push(
294 ChatCompletionRequestUserMessageContentPart::InputAudio(
295 audio_part,
296 ),
297 );
298 }
299 }
300 }
301 ChatCompletionRequestUserMessageContent::Array(out)
302 }
303 None => ChatCompletionRequestUserMessageContent::Text(String::new()),
304 };
305
306 Ok(ChatCompletionRequestUserMessageArgs::default()
307 .content(content)
308 .build()
309 .map_err(|e| LLMError::Other(e.to_string()))?
310 .into())
311 }
312 Role::Assistant => {
313 let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
314 if !text_only_content.is_empty() {
315 builder.content(text_only_content);
316 }
317
318 if let Some(ref tool_calls) = msg.tool_calls {
320 let converted_calls: Vec<_> = tool_calls
321 .iter()
322 .map(|tc| async_openai::types::ChatCompletionMessageToolCall {
323 id: tc.id.clone(),
324 r#type: ChatCompletionToolType::Function,
325 function: async_openai::types::FunctionCall {
326 name: tc.function.name.clone(),
327 arguments: tc.function.arguments.clone(),
328 },
329 })
330 .collect();
331 builder.tool_calls(converted_calls);
332 }
333
334 Ok(builder
335 .build()
336 .map_err(|e| LLMError::Other(e.to_string()))?
337 .into())
338 }
339 Role::Tool => {
340 let tool_call_id = msg
341 .tool_call_id
342 .clone()
343 .unwrap_or_else(|| "unknown".to_string());
344
345 Ok(ChatCompletionRequestToolMessageArgs::default()
346 .tool_call_id(tool_call_id)
347 .content(text_only_content)
348 .build()
349 .map_err(|e| LLMError::Other(e.to_string()))?
350 .into())
351 }
352 }
353 }
354
355 fn convert_tools(
357 tools: &[Tool],
358 ) -> Result<Vec<async_openai::types::ChatCompletionTool>, LLMError> {
359 tools
360 .iter()
361 .map(|tool| {
362 let function = FunctionObjectArgs::default()
363 .name(&tool.function.name)
364 .description(tool.function.description.clone().unwrap_or_default())
365 .parameters(
366 tool.function
367 .parameters
368 .clone()
369 .unwrap_or(serde_json::json!({})),
370 )
371 .build()
372 .map_err(|e| LLMError::Other(e.to_string()))?;
373
374 ChatCompletionToolArgs::default()
375 .r#type(ChatCompletionToolType::Function)
376 .function(function)
377 .build()
378 .map_err(|e| LLMError::Other(e.to_string()))
379 })
380 .collect()
381 }
382
383 fn convert_response(
385 response: async_openai::types::CreateChatCompletionResponse,
386 ) -> ChatCompletionResponse {
387 let choices: Vec<Choice> = response
388 .choices
389 .into_iter()
390 .map(|choice| {
391 let message = Self::convert_response_message(choice.message);
392 let finish_reason = choice.finish_reason.map(|r| match r {
393 async_openai::types::FinishReason::Stop => FinishReason::Stop,
394 async_openai::types::FinishReason::Length => FinishReason::Length,
395 async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
396 async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
397 async_openai::types::FinishReason::FunctionCall => FinishReason::ToolCalls,
398 });
399
400 Choice {
401 index: choice.index,
402 message,
403 finish_reason,
404 logprobs: None,
405 }
406 })
407 .collect();
408
409 let usage = response.usage.map(|u| Usage {
410 prompt_tokens: u.prompt_tokens,
411 completion_tokens: u.completion_tokens,
412 total_tokens: u.total_tokens,
413 });
414
415 ChatCompletionResponse {
416 id: response.id,
417 object: response.object,
418 created: response.created as u64,
419 model: response.model,
420 choices,
421 usage,
422 system_fingerprint: response.system_fingerprint,
423 }
424 }
425
426 fn convert_response_message(
428 msg: async_openai::types::ChatCompletionResponseMessage,
429 ) -> ChatMessage {
430 let content = msg.content.map(MessageContent::Text);
431
432 let tool_calls = msg.tool_calls.map(|calls| {
433 calls
434 .into_iter()
435 .map(|tc| ToolCall {
436 id: tc.id,
437 call_type: "function".to_string(),
438 function: FunctionCall {
439 name: tc.function.name,
440 arguments: tc.function.arguments,
441 },
442 })
443 .collect()
444 });
445
446 ChatMessage {
447 role: Role::Assistant,
448 content,
449 name: None,
450 tool_calls,
451 tool_call_id: None,
452 }
453 }
454}
455
456#[async_trait]
457impl LLMProvider for OpenAIProvider {
458 fn name(&self) -> &str {
459 "openai"
460 }
461
462 fn default_model(&self) -> &str {
463 &self.config.default_model
464 }
465
466 fn supported_models(&self) -> Vec<&str> {
467 vec![
468 "gpt-4o",
469 "gpt-4o-mini",
470 "gpt-4-turbo",
471 "gpt-4",
472 "gpt-3.5-turbo",
473 "o1",
474 "o1-mini",
475 "o1-preview",
476 ]
477 }
478
479 fn supports_streaming(&self) -> bool {
480 true
481 }
482
483 fn supports_tools(&self) -> bool {
484 true
485 }
486
487 fn supports_vision(&self) -> bool {
488 true
489 }
490
491 fn supports_embedding(&self) -> bool {
492 true
493 }
494
495 async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
496 let messages = Self::convert_messages(&request.messages)?;
497
498 let model = if request.model.is_empty() {
499 self.config.default_model.clone()
500 } else {
501 request.model.clone()
502 };
503
504 let mut builder = CreateChatCompletionRequestArgs::default();
505 builder.model(&model).messages(messages);
506
507 if let Some(temp) = request.temperature {
509 builder.temperature(temp);
510 } else {
511 builder.temperature(self.config.default_temperature);
512 }
513
514 if let Some(max_tokens) = request.max_tokens {
515 builder.max_tokens(max_tokens);
516 }
517
518 if let Some(top_p) = request.top_p {
519 builder.top_p(top_p);
520 }
521
522 if let Some(ref stop) = request.stop {
523 builder.stop(stop.clone());
524 }
525
526 if let Some(freq_penalty) = request.frequency_penalty {
527 builder.frequency_penalty(freq_penalty);
528 }
529
530 if let Some(pres_penalty) = request.presence_penalty {
531 builder.presence_penalty(pres_penalty);
532 }
533
534 if let Some(ref user) = request.user {
535 builder.user(user);
536 }
537
538 if let Some(ref tools) = request.tools
540 && !tools.is_empty()
541 {
542 let converted_tools = Self::convert_tools(tools)?;
543 builder.tools(converted_tools);
544
545 if let Some(ref choice) = request.tool_choice {
547 let tc = match choice {
548 ToolChoice::Auto => ChatCompletionToolChoiceOption::Auto,
549 ToolChoice::None => ChatCompletionToolChoiceOption::None,
550 ToolChoice::Required => ChatCompletionToolChoiceOption::Required,
551 ToolChoice::Specific { function, .. } => ChatCompletionToolChoiceOption::Named(
552 async_openai::types::ChatCompletionNamedToolChoice {
553 r#type: ChatCompletionToolType::Function,
554 function: async_openai::types::FunctionName {
555 name: function.name.clone(),
556 },
557 },
558 ),
559 };
560 builder.tool_choice(tc);
561 }
562 }
563
564 if let Some(ref format) = request.response_format
566 && format.format_type == "json_object"
567 {
568 builder.response_format(async_openai::types::ResponseFormat::JsonObject);
569 }
570
571 let openai_request = builder
572 .build()
573 .map_err(|e| LLMError::ConfigError(e.to_string()))?;
574
575 let response = self
576 .client
577 .chat()
578 .create(openai_request)
579 .await
580 .map_err(Self::convert_error)?;
581
582 Ok(Self::convert_response(response))
583 }
584
585 async fn chat_stream(&self, request: ChatCompletionRequest) -> LLMResult<ChatStream> {
586 let messages = Self::convert_messages(&request.messages)?;
587
588 let model = if request.model.is_empty() {
589 self.config.default_model.clone()
590 } else {
591 request.model.clone()
592 };
593
594 let mut builder = CreateChatCompletionRequestArgs::default();
595 builder.model(&model).messages(messages).stream(true);
596
597 if let Some(temp) = request.temperature {
598 builder.temperature(temp);
599 }
600
601 if let Some(max_tokens) = request.max_tokens {
602 builder.max_tokens(max_tokens);
603 }
604
605 if let Some(ref tools) = request.tools
607 && !tools.is_empty()
608 {
609 let converted_tools = Self::convert_tools(tools)?;
610 builder.tools(converted_tools);
611 }
612
613 let openai_request = builder
614 .build()
615 .map_err(|e| LLMError::ConfigError(e.to_string()))?;
616
617 let stream = self
618 .client
619 .chat()
620 .create_stream(openai_request)
621 .await
622 .map_err(Self::convert_error)?;
623
624 let converted_stream = stream
626 .filter_map(|result| async move {
627 match result {
628 Ok(chunk) => Some(Ok(Self::convert_chunk(chunk))),
629 Err(e) => {
630 let err_str = e.to_string();
631 if err_str.contains("stream did not contain valid UTF-8") || err_str.contains("utf8") {
633 tracing::warn!("Skipping invalid UTF-8 chunk from stream (may happen with some OpenAI-compatible APIs)");
634 None
635 } else {
636 Some(Err(Self::convert_error(e)))
637 }
638 }
639 }
640 });
641
642 Ok(Box::pin(converted_stream))
643 }
644
645 async fn embedding(&self, request: EmbeddingRequest) -> LLMResult<EmbeddingResponse> {
646 use async_openai::types::CreateEmbeddingRequestArgs;
647
648 let input = match request.input {
649 EmbeddingInput::Single(s) => vec![s],
650 EmbeddingInput::Multiple(v) => v,
651 };
652
653 let openai_request = CreateEmbeddingRequestArgs::default()
654 .model(&request.model)
655 .input(input)
656 .build()
657 .map_err(|e| LLMError::ConfigError(e.to_string()))?;
658
659 let response = self
660 .client
661 .embeddings()
662 .create(openai_request)
663 .await
664 .map_err(Self::convert_error)?;
665
666 let data: Vec<EmbeddingData> = response
667 .data
668 .into_iter()
669 .map(|d| EmbeddingData {
670 object: "embedding".to_string(),
671 index: d.index,
672 embedding: d.embedding,
673 })
674 .collect();
675
676 Ok(EmbeddingResponse {
677 object: "list".to_string(),
678 model: response.model,
679 data,
680 usage: EmbeddingUsage {
681 prompt_tokens: response.usage.prompt_tokens,
682 total_tokens: response.usage.total_tokens,
683 },
684 })
685 }
686
687 async fn health_check(&self) -> LLMResult<bool> {
688 let request = ChatCompletionRequest::new(&self.config.default_model)
690 .system("Say 'ok'")
691 .max_tokens(5);
692
693 match self.chat(request).await {
694 Ok(_) => Ok(true),
695 Err(_) => Ok(false),
696 }
697 }
698
699 async fn get_model_info(&self, model: &str) -> LLMResult<ModelInfo> {
700 let info = match model {
702 "gpt-4o" => ModelInfo {
703 id: "gpt-4o".to_string(),
704 name: "GPT-4o".to_string(),
705 description: Some("Most capable GPT-4 model with vision".to_string()),
706 context_window: Some(128000),
707 max_output_tokens: Some(16384),
708 training_cutoff: Some("2023-10".to_string()),
709 capabilities: ModelCapabilities {
710 streaming: true,
711 tools: true,
712 vision: true,
713 json_mode: true,
714 json_schema: true,
715 },
716 },
717 "gpt-4o-mini" => ModelInfo {
718 id: "gpt-4o-mini".to_string(),
719 name: "GPT-4o Mini".to_string(),
720 description: Some("Smaller, faster GPT-4o".to_string()),
721 context_window: Some(128000),
722 max_output_tokens: Some(16384),
723 training_cutoff: Some("2023-10".to_string()),
724 capabilities: ModelCapabilities {
725 streaming: true,
726 tools: true,
727 vision: true,
728 json_mode: true,
729 json_schema: true,
730 },
731 },
732 "gpt-4-turbo" => ModelInfo {
733 id: "gpt-4-turbo".to_string(),
734 name: "GPT-4 Turbo".to_string(),
735 description: Some("GPT-4 Turbo with vision".to_string()),
736 context_window: Some(128000),
737 max_output_tokens: Some(4096),
738 training_cutoff: Some("2023-12".to_string()),
739 capabilities: ModelCapabilities {
740 streaming: true,
741 tools: true,
742 vision: true,
743 json_mode: true,
744 json_schema: false,
745 },
746 },
747 "gpt-3.5-turbo" => ModelInfo {
748 id: "gpt-3.5-turbo".to_string(),
749 name: "GPT-3.5 Turbo".to_string(),
750 description: Some("Fast and cost-effective".to_string()),
751 context_window: Some(16385),
752 max_output_tokens: Some(4096),
753 training_cutoff: Some("2021-09".to_string()),
754 capabilities: ModelCapabilities {
755 streaming: true,
756 tools: true,
757 vision: false,
758 json_mode: true,
759 json_schema: false,
760 },
761 },
762 _ => ModelInfo {
763 id: model.to_string(),
764 name: model.to_string(),
765 description: None,
766 context_window: None,
767 max_output_tokens: None,
768 training_cutoff: None,
769 capabilities: ModelCapabilities::default(),
770 },
771 };
772
773 Ok(info)
774 }
775}
776
777impl OpenAIProvider {
778 fn convert_chunk(
780 chunk: async_openai::types::CreateChatCompletionStreamResponse,
781 ) -> ChatCompletionChunk {
782 let choices: Vec<ChunkChoice> = chunk
783 .choices
784 .into_iter()
785 .map(|choice| {
786 let delta = ChunkDelta {
787 role: choice.delta.role.map(|_| Role::Assistant),
788 content: choice.delta.content,
789 tool_calls: choice.delta.tool_calls.map(|calls| {
790 calls
791 .into_iter()
792 .map(|tc| ToolCallDelta {
793 index: tc.index,
794 id: tc.id,
795 call_type: Some("function".to_string()),
796 function: tc.function.map(|f| FunctionCallDelta {
797 name: f.name,
798 arguments: f.arguments,
799 }),
800 })
801 .collect()
802 }),
803 };
804
805 let finish_reason = choice.finish_reason.map(|r| match r {
806 async_openai::types::FinishReason::Stop => FinishReason::Stop,
807 async_openai::types::FinishReason::Length => FinishReason::Length,
808 async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
809 async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
810 async_openai::types::FinishReason::FunctionCall => FinishReason::ToolCalls,
811 });
812
813 ChunkChoice {
814 index: choice.index,
815 delta,
816 finish_reason,
817 }
818 })
819 .collect();
820
821 ChatCompletionChunk {
822 id: chunk.id,
823 object: "chat.completion.chunk".to_string(),
824 created: chunk.created as u64,
825 model: chunk.model,
826 choices,
827 usage: chunk.usage.map(|u| Usage {
828 prompt_tokens: u.prompt_tokens,
829 completion_tokens: u.completion_tokens,
830 total_tokens: u.total_tokens,
831 }),
832 }
833 }
834
835 fn convert_error(err: async_openai::error::OpenAIError) -> LLMError {
837 match err {
838 async_openai::error::OpenAIError::ApiError(api_err) => {
839 let code = api_err.code.clone();
840 let message = api_err.message.clone();
841
842 if message.contains("rate limit") {
844 LLMError::RateLimited(message)
845 } else if message.contains("quota") || message.contains("billing") {
846 LLMError::QuotaExceeded(message)
847 } else if message.contains("model") && message.contains("not found") {
848 LLMError::ModelNotFound(message)
849 } else if message.contains("context") || message.contains("tokens") {
850 LLMError::ContextLengthExceeded(message)
851 } else if message.contains("content") && message.contains("filter") {
852 LLMError::ContentFiltered(message)
853 } else {
854 LLMError::ApiError { code, message }
855 }
856 }
857 async_openai::error::OpenAIError::Reqwest(e) => {
858 if e.is_timeout() {
859 LLMError::Timeout(e.to_string())
860 } else {
861 LLMError::NetworkError(e.to_string())
862 }
863 }
864 async_openai::error::OpenAIError::InvalidArgument(msg) => LLMError::ConfigError(msg),
865 _ => LLMError::Other(err.to_string()),
866 }
867 }
868
869 pub fn openai(api_key: impl Into<String>) -> OpenAIProvider {
871 OpenAIProvider::new(api_key)
872 }
873
874 pub fn openai_compatible(
876 base_url: impl Into<String>,
877 api_key: impl Into<String>,
878 model: impl Into<String>,
879 ) -> OpenAIProvider {
880 let config = OpenAIConfig::new(api_key)
881 .with_base_url(base_url)
882 .with_model(model);
883 OpenAIProvider::with_config(config)
884 }
885}
886
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder setter should overwrite exactly the field it names.
    #[test]
    fn test_config_builder() {
        let OpenAIConfig {
            api_key,
            base_url,
            default_model,
            default_temperature,
            default_max_tokens,
            ..
        } = OpenAIConfig::new("sk-test")
            .with_base_url("http://localhost:8080")
            .with_model("gpt-4")
            .with_temperature(0.5)
            .with_max_tokens(2048);

        assert_eq!(api_key, "sk-test");
        assert_eq!(base_url.as_deref(), Some("http://localhost:8080"));
        assert_eq!(default_model, "gpt-4");
        assert_eq!(default_temperature, 0.5);
        assert_eq!(default_max_tokens, 2048);
    }

    /// The provider identifies itself as "openai".
    #[test]
    fn test_provider_name() {
        assert_eq!(OpenAIProvider::new("test-key").name(), "openai");
    }
}