1use crate::llm::client::{LLMClient, LLMResponse, ModelParams, TokenUsage};
24use crate::llm::coordinator::{ConversationMessage, MessageRole};
25use crate::types::{AppError, Result, ToolCall, ToolDefinition};
26use async_openai::{
27 config::OpenAIConfig,
28 types::chat::{
29 ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
30 ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
31 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
32 ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
33 CreateChatCompletionRequestArgs, FunctionCall, FunctionObject,
34 },
35 Client,
36};
37use async_trait::async_trait;
38use futures::StreamExt;
39
/// Client for OpenAI-compatible chat-completion APIs.
///
/// Wraps an `async_openai` client configured with an API key and base URL,
/// the model identifier, and optional sampling parameters that are applied
/// to every request.
pub struct OpenAIClient {
    // Underlying async-openai HTTP client (carries API key and base URL).
    client: Client<OpenAIConfig>,
    // Model identifier sent with every request, e.g. "gpt-4".
    model: String,
    // Optional sampling parameters (temperature, max tokens, top_p, penalties).
    params: ModelParams,
}
46
47impl OpenAIClient {
48 pub fn new(api_key: String, api_base: String, model: String) -> Self {
56 Self::with_params(api_key, api_base, model, ModelParams::default())
57 }
58
59 pub fn with_params(
68 api_key: String,
69 api_base: String,
70 model: String,
71 params: ModelParams,
72 ) -> Self {
73 let config = OpenAIConfig::new()
74 .with_api_key(api_key)
75 .with_api_base(api_base);
76
77 Self {
78 client: Client::with_config(config),
79 model,
80 params,
81 }
82 }
83
84 fn convert_tool(tool: &ToolDefinition) -> ChatCompletionTools {
86 ChatCompletionTools::Function(ChatCompletionTool {
87 function: FunctionObject {
88 name: tool.name.clone(),
89 description: Some(tool.description.clone()),
90 parameters: Some(tool.parameters.clone()),
91 strict: None,
92 },
93 })
94 }
95
96 fn extract_tool_calls(tool_calls: &[ChatCompletionMessageToolCalls]) -> Vec<ToolCall> {
98 tool_calls
99 .iter()
100 .filter_map(|wrapper| match wrapper {
101 ChatCompletionMessageToolCalls::Function(call) => Some(ToolCall {
102 id: call.id.clone(),
103 name: call.function.name.clone(),
104 arguments: serde_json::from_str(&call.function.arguments)
105 .unwrap_or(serde_json::json!({})),
106 }),
107 ChatCompletionMessageToolCalls::Custom(_) => None,
108 })
109 .collect()
110 }
111
112 fn convert_conversation_message(
114 &self,
115 msg: &ConversationMessage,
116 ) -> Result<ChatCompletionRequestMessage> {
117 match msg.role {
118 MessageRole::System => {
119 let system_msg = ChatCompletionRequestSystemMessageArgs::default()
120 .content(msg.content.clone())
121 .build()
122 .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
123 Ok(ChatCompletionRequestMessage::System(system_msg))
124 }
125 MessageRole::User => {
126 let user_msg = ChatCompletionRequestUserMessageArgs::default()
127 .content(msg.content.clone())
128 .build()
129 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
130 Ok(ChatCompletionRequestMessage::User(user_msg))
131 }
132 MessageRole::Assistant => {
133 let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
134
135 if !msg.content.is_empty() {
136 builder.content(msg.content.clone());
137 }
138
139 if !msg.tool_calls.is_empty() {
141 let openai_tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
142 .tool_calls
143 .iter()
144 .map(|tc| {
145 ChatCompletionMessageToolCalls::Function(
146 ChatCompletionMessageToolCall {
147 id: tc.id.clone(),
148 function: FunctionCall {
149 name: tc.name.clone(),
150 arguments: serde_json::to_string(&tc.arguments)
151 .unwrap_or_else(|_| "{}".to_string()),
152 },
153 },
154 )
155 })
156 .collect();
157 builder.tool_calls(openai_tool_calls);
158 }
159
160 let assistant_msg = builder.build().map_err(|e| {
161 AppError::LLM(format!("Failed to build assistant message: {}", e))
162 })?;
163 Ok(ChatCompletionRequestMessage::Assistant(assistant_msg))
164 }
165 MessageRole::Tool => {
166 let tool_call_id = msg.tool_call_id.clone().ok_or_else(|| {
167 AppError::LLM("Tool message must have a tool_call_id".to_string())
168 })?;
169
170 let tool_msg = ChatCompletionRequestToolMessageArgs::default()
171 .tool_call_id(tool_call_id)
172 .content(msg.content.clone())
173 .build()
174 .map_err(|e| AppError::LLM(format!("Failed to build tool message: {}", e)))?;
175 Ok(ChatCompletionRequestMessage::Tool(tool_msg))
176 }
177 }
178 }
179}
180
181#[async_trait]
182impl LLMClient for OpenAIClient {
183 async fn generate(&self, prompt: &str) -> Result<String> {
184 let message = ChatCompletionRequestUserMessageArgs::default()
185 .content(prompt)
186 .build()
187 .map_err(|e| AppError::LLM(format!("Failed to build message: {}", e)))?;
188
189 let mut builder = CreateChatCompletionRequestArgs::default();
190 builder.model(&self.model);
191 builder.messages(vec![ChatCompletionRequestMessage::User(message)]);
192
193 if let Some(temp) = self.params.temperature {
195 builder.temperature(temp);
196 }
197 if let Some(max_tokens) = self.params.max_tokens {
198 builder.max_completion_tokens(max_tokens);
199 }
200 if let Some(top_p) = self.params.top_p {
201 builder.top_p(top_p);
202 }
203 if let Some(freq_penalty) = self.params.frequency_penalty {
204 builder.frequency_penalty(freq_penalty);
205 }
206 if let Some(pres_penalty) = self.params.presence_penalty {
207 builder.presence_penalty(pres_penalty);
208 }
209
210 let request = builder
211 .build()
212 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
213
214 let response = self
215 .client
216 .chat()
217 .create(request)
218 .await
219 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
220
221 response
222 .choices
223 .first()
224 .and_then(|choice| choice.message.content.clone())
225 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
226 }
227
228 async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String> {
229 let system_message = ChatCompletionRequestSystemMessageArgs::default()
230 .content(system)
231 .build()
232 .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
233
234 let user_message = ChatCompletionRequestUserMessageArgs::default()
235 .content(prompt)
236 .build()
237 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
238
239 let mut builder = CreateChatCompletionRequestArgs::default();
240 builder.model(&self.model);
241 builder.messages(vec![
242 ChatCompletionRequestMessage::System(system_message),
243 ChatCompletionRequestMessage::User(user_message),
244 ]);
245
246 if let Some(temp) = self.params.temperature {
248 builder.temperature(temp);
249 }
250 if let Some(max_tokens) = self.params.max_tokens {
251 builder.max_completion_tokens(max_tokens);
252 }
253 if let Some(top_p) = self.params.top_p {
254 builder.top_p(top_p);
255 }
256 if let Some(freq_penalty) = self.params.frequency_penalty {
257 builder.frequency_penalty(freq_penalty);
258 }
259 if let Some(pres_penalty) = self.params.presence_penalty {
260 builder.presence_penalty(pres_penalty);
261 }
262
263 let request = builder
264 .build()
265 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
266
267 let response = self
268 .client
269 .chat()
270 .create(request)
271 .await
272 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
273
274 response
275 .choices
276 .first()
277 .and_then(|choice| choice.message.content.clone())
278 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
279 }
280
281 async fn generate_with_history(&self, messages: &[(String, String)]) -> Result<LLMResponse> {
282 let chat_messages: std::result::Result<Vec<ChatCompletionRequestMessage>, AppError> =
283 messages
284 .iter()
285 .map(|(role, content)| {
286 match role.as_str() {
287 "system" => {
288 let msg = ChatCompletionRequestSystemMessageArgs::default()
289 .content(content.as_str())
290 .build()
291 .map_err(|e| {
292 AppError::LLM(format!("Failed to build system message: {}", e))
293 })?;
294 Ok(ChatCompletionRequestMessage::System(msg))
295 }
296 "assistant" => {
297 let msg = ChatCompletionRequestAssistantMessageArgs::default()
298 .content(content.as_str())
299 .build()
300 .map_err(|e| {
301 AppError::LLM(format!(
302 "Failed to build assistant message: {}",
303 e
304 ))
305 })?;
306 Ok(ChatCompletionRequestMessage::Assistant(msg))
307 }
308 _ => {
309 let msg = ChatCompletionRequestUserMessageArgs::default()
311 .content(content.as_str())
312 .build()
313 .map_err(|e| {
314 AppError::LLM(format!("Failed to build user message: {}", e))
315 })?;
316 Ok(ChatCompletionRequestMessage::User(msg))
317 }
318 }
319 })
320 .collect();
321
322 let mut builder = CreateChatCompletionRequestArgs::default();
323 builder.model(&self.model);
324 builder.messages(chat_messages?);
325
326 if let Some(temp) = self.params.temperature {
328 builder.temperature(temp);
329 }
330 if let Some(max_tokens) = self.params.max_tokens {
331 builder.max_completion_tokens(max_tokens);
332 }
333 if let Some(top_p) = self.params.top_p {
334 builder.top_p(top_p);
335 }
336 if let Some(freq_penalty) = self.params.frequency_penalty {
337 builder.frequency_penalty(freq_penalty);
338 }
339 if let Some(pres_penalty) = self.params.presence_penalty {
340 builder.presence_penalty(pres_penalty);
341 }
342
343 let request = builder
344 .build()
345 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
346
347 let response = self
348 .client
349 .chat()
350 .create(request)
351 .await
352 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
353
354 let content = response
355 .choices
356 .first()
357 .and_then(|choice| choice.message.content.clone())
358 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
359
360 #[allow(clippy::unnecessary_cast)]
361 let usage = response
362 .usage
363 .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
364
365 Ok(LLMResponse {
366 content,
367 tool_calls: vec![],
368 finish_reason: "stop".to_string(),
369 usage,
370 })
371 }
372
373 async fn generate_with_tools(
374 &self,
375 prompt: &str,
376 tools: &[ToolDefinition],
377 ) -> Result<LLMResponse> {
378 let openai_tools: Vec<ChatCompletionTools> = tools.iter().map(Self::convert_tool).collect();
379
380 let user_message = ChatCompletionRequestUserMessageArgs::default()
381 .content(prompt)
382 .build()
383 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
384
385 let mut builder = CreateChatCompletionRequestArgs::default();
386 builder.model(&self.model);
387 builder.messages(vec![ChatCompletionRequestMessage::User(user_message)]);
388 builder.tools(openai_tools);
389
390 if let Some(temp) = self.params.temperature {
392 builder.temperature(temp);
393 }
394 if let Some(max_tokens) = self.params.max_tokens {
395 builder.max_completion_tokens(max_tokens);
396 }
397 if let Some(top_p) = self.params.top_p {
398 builder.top_p(top_p);
399 }
400 if let Some(freq_penalty) = self.params.frequency_penalty {
401 builder.frequency_penalty(freq_penalty);
402 }
403 if let Some(pres_penalty) = self.params.presence_penalty {
404 builder.presence_penalty(pres_penalty);
405 }
406
407 let request = builder
408 .build()
409 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
410
411 let response = self
412 .client
413 .chat()
414 .create(request)
415 .await
416 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
417
418 let choice = response
419 .choices
420 .first()
421 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
422
423 let content = choice.message.content.clone().unwrap_or_default();
424
425 let finish_reason = choice
426 .finish_reason
427 .as_ref()
428 .map(|r| format!("{:?}", r).to_lowercase())
429 .unwrap_or_else(|| "stop".to_string());
430
431 let tool_calls = choice
432 .message
433 .tool_calls
434 .as_ref()
435 .map(|calls| Self::extract_tool_calls(calls))
436 .unwrap_or_default();
437
438 #[allow(clippy::unnecessary_cast)]
440 let usage = response
441 .usage
442 .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
443
444 Ok(LLMResponse {
445 content,
446 tool_calls,
447 finish_reason,
448 usage,
449 })
450 }
451
452 async fn generate_with_tools_and_history(
453 &self,
454 messages: &[ConversationMessage],
455 tools: &[ToolDefinition],
456 ) -> Result<LLMResponse> {
457 let openai_messages: Vec<ChatCompletionRequestMessage> = messages
459 .iter()
460 .map(|msg| self.convert_conversation_message(msg))
461 .collect::<Result<Vec<_>>>()?;
462
463 let openai_tools: Vec<ChatCompletionTools> = tools.iter().map(Self::convert_tool).collect();
465
466 let mut builder = CreateChatCompletionRequestArgs::default();
467 builder.model(&self.model);
468 builder.messages(openai_messages);
469
470 if !openai_tools.is_empty() {
471 builder.tools(openai_tools);
472 }
473
474 if let Some(temp) = self.params.temperature {
476 builder.temperature(temp);
477 }
478 if let Some(max_tokens) = self.params.max_tokens {
479 builder.max_completion_tokens(max_tokens);
480 }
481 if let Some(top_p) = self.params.top_p {
482 builder.top_p(top_p);
483 }
484 if let Some(freq_penalty) = self.params.frequency_penalty {
485 builder.frequency_penalty(freq_penalty);
486 }
487 if let Some(pres_penalty) = self.params.presence_penalty {
488 builder.presence_penalty(pres_penalty);
489 }
490
491 let request = builder
492 .build()
493 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
494
495 let response = self
496 .client
497 .chat()
498 .create(request)
499 .await
500 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
501
502 let choice = response
503 .choices
504 .first()
505 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
506
507 let content = choice.message.content.clone().unwrap_or_default();
508
509 let finish_reason = choice
510 .finish_reason
511 .as_ref()
512 .map(|r| format!("{:?}", r).to_lowercase())
513 .unwrap_or_else(|| "stop".to_string());
514
515 let tool_calls = choice
516 .message
517 .tool_calls
518 .as_ref()
519 .map(|calls| Self::extract_tool_calls(calls))
520 .unwrap_or_default();
521
522 #[allow(clippy::unnecessary_cast)]
523 let usage = response
524 .usage
525 .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
526
527 Ok(LLMResponse {
528 content,
529 tool_calls,
530 finish_reason,
531 usage,
532 })
533 }
534
535 async fn stream(
536 &self,
537 prompt: &str,
538 ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
539 let user_message = ChatCompletionRequestUserMessageArgs::default()
540 .content(prompt)
541 .build()
542 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
543
544 let mut builder = CreateChatCompletionRequestArgs::default();
545 builder.model(&self.model);
546 builder.messages(vec![ChatCompletionRequestMessage::User(user_message)]);
547
548 if let Some(temp) = self.params.temperature {
550 builder.temperature(temp);
551 }
552 if let Some(max_tokens) = self.params.max_tokens {
553 builder.max_completion_tokens(max_tokens);
554 }
555 if let Some(top_p) = self.params.top_p {
556 builder.top_p(top_p);
557 }
558 if let Some(freq_penalty) = self.params.frequency_penalty {
559 builder.frequency_penalty(freq_penalty);
560 }
561 if let Some(pres_penalty) = self.params.presence_penalty {
562 builder.presence_penalty(pres_penalty);
563 }
564
565 let request = builder
566 .build()
567 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
568
569 let mut stream = self
570 .client
571 .chat()
572 .create_stream(request)
573 .await
574 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
575
576 let result_stream = async_stream::stream! {
577 while let Some(result) = stream.next().await {
578 match result {
579 Ok(response) => {
580 for choice in response.choices {
581 if let Some(content) = choice.delta.content {
582 yield Ok(content);
583 }
584 }
585 }
586 Err(e) => {
587 yield Err(AppError::LLM(format!("Stream error: {}", e)));
588 }
589 }
590 }
591 };
592
593 Ok(Box::new(Box::pin(result_stream)))
594 }
595
596 async fn stream_with_system(
597 &self,
598 system: &str,
599 prompt: &str,
600 ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
601 let system_message = ChatCompletionRequestSystemMessageArgs::default()
602 .content(system)
603 .build()
604 .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
605
606 let user_message = ChatCompletionRequestUserMessageArgs::default()
607 .content(prompt)
608 .build()
609 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
610
611 let mut builder = CreateChatCompletionRequestArgs::default();
612 builder.model(&self.model);
613 builder.messages(vec![
614 ChatCompletionRequestMessage::System(system_message),
615 ChatCompletionRequestMessage::User(user_message),
616 ]);
617
618 if let Some(temp) = self.params.temperature {
620 builder.temperature(temp);
621 }
622 if let Some(max_tokens) = self.params.max_tokens {
623 builder.max_completion_tokens(max_tokens);
624 }
625 if let Some(top_p) = self.params.top_p {
626 builder.top_p(top_p);
627 }
628 if let Some(freq_penalty) = self.params.frequency_penalty {
629 builder.frequency_penalty(freq_penalty);
630 }
631 if let Some(pres_penalty) = self.params.presence_penalty {
632 builder.presence_penalty(pres_penalty);
633 }
634
635 let request = builder
636 .build()
637 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
638
639 let mut stream = self
640 .client
641 .chat()
642 .create_stream(request)
643 .await
644 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
645
646 let result_stream = async_stream::stream! {
647 while let Some(result) = stream.next().await {
648 match result {
649 Ok(response) => {
650 for choice in response.choices {
651 if let Some(content) = choice.delta.content {
652 yield Ok(content);
653 }
654 }
655 }
656 Err(e) => {
657 yield Err(AppError::LLM(format!("Stream error: {}", e)));
658 }
659 }
660 }
661 };
662
663 Ok(Box::new(Box::pin(result_stream)))
664 }
665
666 async fn stream_with_history(
667 &self,
668 messages: &[(String, String)],
669 ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
670 let chat_messages: std::result::Result<Vec<ChatCompletionRequestMessage>, AppError> =
671 messages
672 .iter()
673 .map(|(role, content)| {
674 match role.as_str() {
675 "system" => {
676 let msg = ChatCompletionRequestSystemMessageArgs::default()
677 .content(content.as_str())
678 .build()
679 .map_err(|e| {
680 AppError::LLM(format!("Failed to build system message: {}", e))
681 })?;
682 Ok(ChatCompletionRequestMessage::System(msg))
683 }
684 "assistant" => {
685 let msg = ChatCompletionRequestAssistantMessageArgs::default()
686 .content(content.as_str())
687 .build()
688 .map_err(|e| {
689 AppError::LLM(format!(
690 "Failed to build assistant message: {}",
691 e
692 ))
693 })?;
694 Ok(ChatCompletionRequestMessage::Assistant(msg))
695 }
696 _ => {
697 let msg = ChatCompletionRequestUserMessageArgs::default()
699 .content(content.as_str())
700 .build()
701 .map_err(|e| {
702 AppError::LLM(format!("Failed to build user message: {}", e))
703 })?;
704 Ok(ChatCompletionRequestMessage::User(msg))
705 }
706 }
707 })
708 .collect();
709
710 let mut builder = CreateChatCompletionRequestArgs::default();
711 builder.model(&self.model);
712 builder.messages(chat_messages?);
713
714 if let Some(temp) = self.params.temperature {
716 builder.temperature(temp);
717 }
718 if let Some(max_tokens) = self.params.max_tokens {
719 builder.max_completion_tokens(max_tokens);
720 }
721 if let Some(top_p) = self.params.top_p {
722 builder.top_p(top_p);
723 }
724 if let Some(freq_penalty) = self.params.frequency_penalty {
725 builder.frequency_penalty(freq_penalty);
726 }
727 if let Some(pres_penalty) = self.params.presence_penalty {
728 builder.presence_penalty(pres_penalty);
729 }
730
731 let request = builder
732 .build()
733 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
734
735 let mut stream = self
736 .client
737 .chat()
738 .create_stream(request)
739 .await
740 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
741
742 let result_stream = async_stream::stream! {
743 while let Some(result) = stream.next().await {
744 match result {
745 Ok(response) => {
746 for choice in response.choices {
747 if let Some(content) = choice.delta.content {
748 yield Ok(content);
749 }
750 }
751 }
752 Err(e) => {
753 yield Err(AppError::LLM(format!("Stream error: {}", e)));
754 }
755 }
756 }
757 };
758
759 Ok(Box::new(Box::pin(result_stream)))
760 }
761
762 fn model_name(&self) -> &str {
763 &self.model
764 }
765}
766
#[cfg(test)]
mod tests {
    use super::*;

    /// The model name passed at construction is reported back verbatim.
    #[test]
    fn test_client_creation() {
        let client = OpenAIClient::new(
            String::from("test-key"),
            String::from("https://api.openai.com/v1"),
            String::from("gpt-4"),
        );

        assert_eq!("gpt-4", client.model_name());
    }

    /// `convert_tool` wraps the definition in the `Function` variant and
    /// carries the name and description through unchanged.
    #[test]
    fn test_tool_conversion() {
        let parameters = serde_json::json!({
            "type": "object",
            "properties": {
                "operation": {"type": "string"},
                "a": {"type": "number"},
                "b": {"type": "number"}
            },
            "required": ["operation", "a", "b"]
        });
        let tool = ToolDefinition {
            name: String::from("calculator"),
            description: String::from("Performs math operations"),
            parameters,
        };

        let ChatCompletionTools::Function(chat_tool) = OpenAIClient::convert_tool(&tool) else {
            panic!("Expected Function variant, got Custom");
        };

        assert_eq!("calculator", chat_tool.function.name);
        assert_eq!(
            Some(String::from("Performs math operations")),
            chat_tool.function.description
        );
    }
}