1use crate::llm::client::{LLMClient, LLMResponse, ModelParams, TokenUsage};
24use crate::llm::coordinator::{ConversationMessage, MessageRole};
25use crate::types::{AppError, Result, ToolCall, ToolDefinition};
26use async_openai::{
27 config::OpenAIConfig,
28 types::chat::{
29 ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
30 ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
31 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
32 ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
33 CreateChatCompletionRequestArgs, FunctionCall, FunctionObject,
34 },
35 Client,
36};
37use async_trait::async_trait;
38use futures::StreamExt;
39
/// LLM client backed by the OpenAI chat-completions API (or any
/// OpenAI-compatible endpoint, selected via the configured base URL).
pub struct OpenAIClient {
    // Underlying async-openai HTTP client; carries the API key and base URL.
    client: Client<OpenAIConfig>,
    // Model identifier sent with every request, e.g. "gpt-4".
    model: String,
    // Optional sampling parameters (temperature, max tokens, top_p, penalties)
    // applied to each request when set.
    params: ModelParams,
}
46
47impl OpenAIClient {
48 pub fn new(api_key: String, api_base: String, model: String) -> Self {
56 Self::with_params(api_key, api_base, model, ModelParams::default())
57 }
58
59 pub fn with_params(
68 api_key: String,
69 api_base: String,
70 model: String,
71 params: ModelParams,
72 ) -> Self {
73 let config = OpenAIConfig::new()
74 .with_api_key(api_key)
75 .with_api_base(api_base);
76
77 Self {
78 client: Client::with_config(config),
79 model,
80 params,
81 }
82 }
83
84 fn convert_tool(tool: &ToolDefinition) -> ChatCompletionTools {
86 ChatCompletionTools::Function(ChatCompletionTool {
87 function: FunctionObject {
88 name: tool.name.clone(),
89 description: Some(tool.description.clone()),
90 parameters: Some(tool.parameters.clone()),
91 strict: None,
92 },
93 })
94 }
95
96 fn extract_tool_calls(tool_calls: &[ChatCompletionMessageToolCalls]) -> Vec<ToolCall> {
98 tool_calls
99 .iter()
100 .filter_map(|wrapper| match wrapper {
101 ChatCompletionMessageToolCalls::Function(call) => Some(ToolCall {
102 id: call.id.clone(),
103 name: call.function.name.clone(),
104 arguments: serde_json::from_str(&call.function.arguments)
105 .unwrap_or(serde_json::json!({})),
106 }),
107 ChatCompletionMessageToolCalls::Custom(_) => None,
108 })
109 .collect()
110 }
111
112 fn convert_conversation_message(
114 &self,
115 msg: &ConversationMessage,
116 ) -> Result<ChatCompletionRequestMessage> {
117 match msg.role {
118 MessageRole::System => {
119 let system_msg = ChatCompletionRequestSystemMessageArgs::default()
120 .content(msg.content.clone())
121 .build()
122 .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
123 Ok(ChatCompletionRequestMessage::System(system_msg))
124 }
125 MessageRole::User => {
126 let user_msg = ChatCompletionRequestUserMessageArgs::default()
127 .content(msg.content.clone())
128 .build()
129 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
130 Ok(ChatCompletionRequestMessage::User(user_msg))
131 }
132 MessageRole::Assistant => {
133 let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
134
135 if !msg.content.is_empty() {
136 builder.content(msg.content.clone());
137 }
138
139 if !msg.tool_calls.is_empty() {
141 let openai_tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
142 .tool_calls
143 .iter()
144 .map(|tc| {
145 ChatCompletionMessageToolCalls::Function(
146 ChatCompletionMessageToolCall {
147 id: tc.id.clone(),
148 function: FunctionCall {
149 name: tc.name.clone(),
150 arguments: serde_json::to_string(&tc.arguments)
151 .unwrap_or_else(|_| "{}".to_string()),
152 },
153 },
154 )
155 })
156 .collect();
157 builder.tool_calls(openai_tool_calls);
158 }
159
160 let assistant_msg = builder.build().map_err(|e| {
161 AppError::LLM(format!("Failed to build assistant message: {}", e))
162 })?;
163 Ok(ChatCompletionRequestMessage::Assistant(assistant_msg))
164 }
165 MessageRole::Tool => {
166 let tool_call_id = msg.tool_call_id.clone().ok_or_else(|| {
167 AppError::LLM("Tool message must have a tool_call_id".to_string())
168 })?;
169
170 let tool_msg = ChatCompletionRequestToolMessageArgs::default()
171 .tool_call_id(tool_call_id)
172 .content(msg.content.clone())
173 .build()
174 .map_err(|e| AppError::LLM(format!("Failed to build tool message: {}", e)))?;
175 Ok(ChatCompletionRequestMessage::Tool(tool_msg))
176 }
177 }
178 }
179}
180
181#[async_trait]
182impl LLMClient for OpenAIClient {
183 async fn generate(&self, prompt: &str) -> Result<String> {
184 let message = ChatCompletionRequestUserMessageArgs::default()
185 .content(prompt)
186 .build()
187 .map_err(|e| AppError::LLM(format!("Failed to build message: {}", e)))?;
188
189 let mut builder = CreateChatCompletionRequestArgs::default();
190 builder.model(&self.model);
191 builder.messages(vec![ChatCompletionRequestMessage::User(message)]);
192
193 if let Some(temp) = self.params.temperature {
195 builder.temperature(temp);
196 }
197 if let Some(max_tokens) = self.params.max_tokens {
198 builder.max_completion_tokens(max_tokens);
199 }
200 if let Some(top_p) = self.params.top_p {
201 builder.top_p(top_p);
202 }
203 if let Some(freq_penalty) = self.params.frequency_penalty {
204 builder.frequency_penalty(freq_penalty);
205 }
206 if let Some(pres_penalty) = self.params.presence_penalty {
207 builder.presence_penalty(pres_penalty);
208 }
209
210 let request = builder
211 .build()
212 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
213
214 let response = self
215 .client
216 .chat()
217 .create(request)
218 .await
219 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
220
221 response
222 .choices
223 .first()
224 .and_then(|choice| choice.message.content.clone())
225 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
226 }
227
228 async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String> {
229 let system_message = ChatCompletionRequestSystemMessageArgs::default()
230 .content(system)
231 .build()
232 .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
233
234 let user_message = ChatCompletionRequestUserMessageArgs::default()
235 .content(prompt)
236 .build()
237 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
238
239 let mut builder = CreateChatCompletionRequestArgs::default();
240 builder.model(&self.model);
241 builder.messages(vec![
242 ChatCompletionRequestMessage::System(system_message),
243 ChatCompletionRequestMessage::User(user_message),
244 ]);
245
246 if let Some(temp) = self.params.temperature {
248 builder.temperature(temp);
249 }
250 if let Some(max_tokens) = self.params.max_tokens {
251 builder.max_completion_tokens(max_tokens);
252 }
253 if let Some(top_p) = self.params.top_p {
254 builder.top_p(top_p);
255 }
256 if let Some(freq_penalty) = self.params.frequency_penalty {
257 builder.frequency_penalty(freq_penalty);
258 }
259 if let Some(pres_penalty) = self.params.presence_penalty {
260 builder.presence_penalty(pres_penalty);
261 }
262
263 let request = builder
264 .build()
265 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
266
267 let response = self
268 .client
269 .chat()
270 .create(request)
271 .await
272 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
273
274 response
275 .choices
276 .first()
277 .and_then(|choice| choice.message.content.clone())
278 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
279 }
280
281 async fn generate_with_history(&self, messages: &[(String, String)]) -> Result<String> {
282 let chat_messages: std::result::Result<Vec<ChatCompletionRequestMessage>, AppError> =
283 messages
284 .iter()
285 .map(|(role, content)| {
286 match role.as_str() {
287 "system" => {
288 let msg = ChatCompletionRequestSystemMessageArgs::default()
289 .content(content.as_str())
290 .build()
291 .map_err(|e| {
292 AppError::LLM(format!("Failed to build system message: {}", e))
293 })?;
294 Ok(ChatCompletionRequestMessage::System(msg))
295 }
296 "assistant" => {
297 let msg = ChatCompletionRequestAssistantMessageArgs::default()
298 .content(content.as_str())
299 .build()
300 .map_err(|e| {
301 AppError::LLM(format!(
302 "Failed to build assistant message: {}",
303 e
304 ))
305 })?;
306 Ok(ChatCompletionRequestMessage::Assistant(msg))
307 }
308 _ => {
309 let msg = ChatCompletionRequestUserMessageArgs::default()
311 .content(content.as_str())
312 .build()
313 .map_err(|e| {
314 AppError::LLM(format!("Failed to build user message: {}", e))
315 })?;
316 Ok(ChatCompletionRequestMessage::User(msg))
317 }
318 }
319 })
320 .collect();
321
322 let mut builder = CreateChatCompletionRequestArgs::default();
323 builder.model(&self.model);
324 builder.messages(chat_messages?);
325
326 if let Some(temp) = self.params.temperature {
328 builder.temperature(temp);
329 }
330 if let Some(max_tokens) = self.params.max_tokens {
331 builder.max_completion_tokens(max_tokens);
332 }
333 if let Some(top_p) = self.params.top_p {
334 builder.top_p(top_p);
335 }
336 if let Some(freq_penalty) = self.params.frequency_penalty {
337 builder.frequency_penalty(freq_penalty);
338 }
339 if let Some(pres_penalty) = self.params.presence_penalty {
340 builder.presence_penalty(pres_penalty);
341 }
342
343 let request = builder
344 .build()
345 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
346
347 let response = self
348 .client
349 .chat()
350 .create(request)
351 .await
352 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
353
354 response
355 .choices
356 .first()
357 .and_then(|choice| choice.message.content.clone())
358 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
359 }
360
361 async fn generate_with_tools(
362 &self,
363 prompt: &str,
364 tools: &[ToolDefinition],
365 ) -> Result<LLMResponse> {
366 let openai_tools: Vec<ChatCompletionTools> = tools.iter().map(Self::convert_tool).collect();
367
368 let user_message = ChatCompletionRequestUserMessageArgs::default()
369 .content(prompt)
370 .build()
371 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
372
373 let mut builder = CreateChatCompletionRequestArgs::default();
374 builder.model(&self.model);
375 builder.messages(vec![ChatCompletionRequestMessage::User(user_message)]);
376 builder.tools(openai_tools);
377
378 if let Some(temp) = self.params.temperature {
380 builder.temperature(temp);
381 }
382 if let Some(max_tokens) = self.params.max_tokens {
383 builder.max_completion_tokens(max_tokens);
384 }
385 if let Some(top_p) = self.params.top_p {
386 builder.top_p(top_p);
387 }
388 if let Some(freq_penalty) = self.params.frequency_penalty {
389 builder.frequency_penalty(freq_penalty);
390 }
391 if let Some(pres_penalty) = self.params.presence_penalty {
392 builder.presence_penalty(pres_penalty);
393 }
394
395 let request = builder
396 .build()
397 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
398
399 let response = self
400 .client
401 .chat()
402 .create(request)
403 .await
404 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
405
406 let choice = response
407 .choices
408 .first()
409 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
410
411 let content = choice.message.content.clone().unwrap_or_default();
412
413 let finish_reason = choice
414 .finish_reason
415 .as_ref()
416 .map(|r| format!("{:?}", r).to_lowercase())
417 .unwrap_or_else(|| "stop".to_string());
418
419 let tool_calls = choice
420 .message
421 .tool_calls
422 .as_ref()
423 .map(|calls| Self::extract_tool_calls(calls))
424 .unwrap_or_default();
425
426 #[allow(clippy::unnecessary_cast)]
428 let usage = response
429 .usage
430 .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
431
432 Ok(LLMResponse {
433 content,
434 tool_calls,
435 finish_reason,
436 usage,
437 })
438 }
439
440 async fn generate_with_tools_and_history(
441 &self,
442 messages: &[ConversationMessage],
443 tools: &[ToolDefinition],
444 ) -> Result<LLMResponse> {
445 let openai_messages: Vec<ChatCompletionRequestMessage> = messages
447 .iter()
448 .map(|msg| self.convert_conversation_message(msg))
449 .collect::<Result<Vec<_>>>()?;
450
451 let openai_tools: Vec<ChatCompletionTools> = tools.iter().map(Self::convert_tool).collect();
453
454 let mut builder = CreateChatCompletionRequestArgs::default();
455 builder.model(&self.model);
456 builder.messages(openai_messages);
457
458 if !openai_tools.is_empty() {
459 builder.tools(openai_tools);
460 }
461
462 if let Some(temp) = self.params.temperature {
464 builder.temperature(temp);
465 }
466 if let Some(max_tokens) = self.params.max_tokens {
467 builder.max_completion_tokens(max_tokens);
468 }
469 if let Some(top_p) = self.params.top_p {
470 builder.top_p(top_p);
471 }
472 if let Some(freq_penalty) = self.params.frequency_penalty {
473 builder.frequency_penalty(freq_penalty);
474 }
475 if let Some(pres_penalty) = self.params.presence_penalty {
476 builder.presence_penalty(pres_penalty);
477 }
478
479 let request = builder
480 .build()
481 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
482
483 let response = self
484 .client
485 .chat()
486 .create(request)
487 .await
488 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
489
490 let choice = response
491 .choices
492 .first()
493 .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
494
495 let content = choice.message.content.clone().unwrap_or_default();
496
497 let finish_reason = choice
498 .finish_reason
499 .as_ref()
500 .map(|r| format!("{:?}", r).to_lowercase())
501 .unwrap_or_else(|| "stop".to_string());
502
503 let tool_calls = choice
504 .message
505 .tool_calls
506 .as_ref()
507 .map(|calls| Self::extract_tool_calls(calls))
508 .unwrap_or_default();
509
510 #[allow(clippy::unnecessary_cast)]
511 let usage = response
512 .usage
513 .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
514
515 Ok(LLMResponse {
516 content,
517 tool_calls,
518 finish_reason,
519 usage,
520 })
521 }
522
523 async fn stream(
524 &self,
525 prompt: &str,
526 ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
527 let user_message = ChatCompletionRequestUserMessageArgs::default()
528 .content(prompt)
529 .build()
530 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
531
532 let mut builder = CreateChatCompletionRequestArgs::default();
533 builder.model(&self.model);
534 builder.messages(vec![ChatCompletionRequestMessage::User(user_message)]);
535
536 if let Some(temp) = self.params.temperature {
538 builder.temperature(temp);
539 }
540 if let Some(max_tokens) = self.params.max_tokens {
541 builder.max_completion_tokens(max_tokens);
542 }
543 if let Some(top_p) = self.params.top_p {
544 builder.top_p(top_p);
545 }
546 if let Some(freq_penalty) = self.params.frequency_penalty {
547 builder.frequency_penalty(freq_penalty);
548 }
549 if let Some(pres_penalty) = self.params.presence_penalty {
550 builder.presence_penalty(pres_penalty);
551 }
552
553 let request = builder
554 .build()
555 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
556
557 let mut stream = self
558 .client
559 .chat()
560 .create_stream(request)
561 .await
562 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
563
564 let result_stream = async_stream::stream! {
565 while let Some(result) = stream.next().await {
566 match result {
567 Ok(response) => {
568 for choice in response.choices {
569 if let Some(content) = choice.delta.content {
570 yield Ok(content);
571 }
572 }
573 }
574 Err(e) => {
575 yield Err(AppError::LLM(format!("Stream error: {}", e)));
576 }
577 }
578 }
579 };
580
581 Ok(Box::new(Box::pin(result_stream)))
582 }
583
584 async fn stream_with_system(
585 &self,
586 system: &str,
587 prompt: &str,
588 ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
589 let system_message = ChatCompletionRequestSystemMessageArgs::default()
590 .content(system)
591 .build()
592 .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
593
594 let user_message = ChatCompletionRequestUserMessageArgs::default()
595 .content(prompt)
596 .build()
597 .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
598
599 let mut builder = CreateChatCompletionRequestArgs::default();
600 builder.model(&self.model);
601 builder.messages(vec![
602 ChatCompletionRequestMessage::System(system_message),
603 ChatCompletionRequestMessage::User(user_message),
604 ]);
605
606 if let Some(temp) = self.params.temperature {
608 builder.temperature(temp);
609 }
610 if let Some(max_tokens) = self.params.max_tokens {
611 builder.max_completion_tokens(max_tokens);
612 }
613 if let Some(top_p) = self.params.top_p {
614 builder.top_p(top_p);
615 }
616 if let Some(freq_penalty) = self.params.frequency_penalty {
617 builder.frequency_penalty(freq_penalty);
618 }
619 if let Some(pres_penalty) = self.params.presence_penalty {
620 builder.presence_penalty(pres_penalty);
621 }
622
623 let request = builder
624 .build()
625 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
626
627 let mut stream = self
628 .client
629 .chat()
630 .create_stream(request)
631 .await
632 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
633
634 let result_stream = async_stream::stream! {
635 while let Some(result) = stream.next().await {
636 match result {
637 Ok(response) => {
638 for choice in response.choices {
639 if let Some(content) = choice.delta.content {
640 yield Ok(content);
641 }
642 }
643 }
644 Err(e) => {
645 yield Err(AppError::LLM(format!("Stream error: {}", e)));
646 }
647 }
648 }
649 };
650
651 Ok(Box::new(Box::pin(result_stream)))
652 }
653
654 async fn stream_with_history(
655 &self,
656 messages: &[(String, String)],
657 ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
658 let chat_messages: std::result::Result<Vec<ChatCompletionRequestMessage>, AppError> =
659 messages
660 .iter()
661 .map(|(role, content)| {
662 match role.as_str() {
663 "system" => {
664 let msg = ChatCompletionRequestSystemMessageArgs::default()
665 .content(content.as_str())
666 .build()
667 .map_err(|e| {
668 AppError::LLM(format!("Failed to build system message: {}", e))
669 })?;
670 Ok(ChatCompletionRequestMessage::System(msg))
671 }
672 "assistant" => {
673 let msg = ChatCompletionRequestAssistantMessageArgs::default()
674 .content(content.as_str())
675 .build()
676 .map_err(|e| {
677 AppError::LLM(format!(
678 "Failed to build assistant message: {}",
679 e
680 ))
681 })?;
682 Ok(ChatCompletionRequestMessage::Assistant(msg))
683 }
684 _ => {
685 let msg = ChatCompletionRequestUserMessageArgs::default()
687 .content(content.as_str())
688 .build()
689 .map_err(|e| {
690 AppError::LLM(format!("Failed to build user message: {}", e))
691 })?;
692 Ok(ChatCompletionRequestMessage::User(msg))
693 }
694 }
695 })
696 .collect();
697
698 let mut builder = CreateChatCompletionRequestArgs::default();
699 builder.model(&self.model);
700 builder.messages(chat_messages?);
701
702 if let Some(temp) = self.params.temperature {
704 builder.temperature(temp);
705 }
706 if let Some(max_tokens) = self.params.max_tokens {
707 builder.max_completion_tokens(max_tokens);
708 }
709 if let Some(top_p) = self.params.top_p {
710 builder.top_p(top_p);
711 }
712 if let Some(freq_penalty) = self.params.frequency_penalty {
713 builder.frequency_penalty(freq_penalty);
714 }
715 if let Some(pres_penalty) = self.params.presence_penalty {
716 builder.presence_penalty(pres_penalty);
717 }
718
719 let request = builder
720 .build()
721 .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
722
723 let mut stream = self
724 .client
725 .chat()
726 .create_stream(request)
727 .await
728 .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
729
730 let result_stream = async_stream::stream! {
731 while let Some(result) = stream.next().await {
732 match result {
733 Ok(response) => {
734 for choice in response.choices {
735 if let Some(content) = choice.delta.content {
736 yield Ok(content);
737 }
738 }
739 }
740 Err(e) => {
741 yield Err(AppError::LLM(format!("Stream error: {}", e)));
742 }
743 }
744 }
745 };
746
747 Ok(Box::new(Box::pin(result_stream)))
748 }
749
750 fn model_name(&self) -> &str {
751 &self.model
752 }
753}
754
#[cfg(test)]
mod tests {
    use super::*;

    // Constructing the client requires no network access, so we can verify
    // the configured model name directly.
    #[test]
    fn test_client_creation() {
        let model = "gpt-4";
        let client = OpenAIClient::new(
            "test-key".to_string(),
            "https://api.openai.com/v1".to_string(),
            model.to_string(),
        );

        assert_eq!(client.model_name(), model);
    }

    // A ToolDefinition must map onto the OpenAI function-tool variant with
    // its name and description carried over intact.
    #[test]
    fn test_tool_conversion() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "operation": {"type": "string"},
                "a": {"type": "number"},
                "b": {"type": "number"}
            },
            "required": ["operation", "a", "b"]
        });
        let tool = ToolDefinition {
            name: "calculator".to_string(),
            description: "Performs math operations".to_string(),
            parameters: schema,
        };

        let ChatCompletionTools::Function(chat_tool) = OpenAIClient::convert_tool(&tool) else {
            panic!("Expected Function variant, got Custom");
        };

        assert_eq!(chat_tool.function.name, "calculator");
        assert_eq!(
            chat_tool.function.description.as_deref(),
            Some("Performs math operations")
        );
    }
}