
agent_runtime/llm/provider/openai.rs

use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};

use super::super::{ChatClient, ChatRequest, ChatResponse, LlmError, LlmResult, TextStream};

const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";

/// OpenAI chat client
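///
/// A minimal usage sketch; the key, model, and `request` value below are
/// placeholders, not something defined in this module:
///
/// ```ignore
/// let client = OpenAIClient::with_model("sk-...", "gpt-4");
/// let response = client.chat(request).await?;
/// println!("{}", response.content);
/// ```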
pub struct OpenAIClient {
    api_key: String,
    model: String,
    http_client: HttpClient,
}

impl OpenAIClient {
    /// Create a new OpenAI client using the default `gpt-4` model
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "gpt-4")
    }

    /// Create a new OpenAI client with a specific model
    pub fn with_model(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            http_client: HttpClient::new(),
        }
    }
}

#[async_trait]
impl ChatClient for OpenAIClient {
    async fn chat(&self, request: ChatRequest) -> LlmResult<ChatResponse> {
        // Build OpenAI API request
        let openai_request = OpenAIChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
        };
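        // Serialized, this becomes a Chat Completions payload along the lines of
        //   {"model": "gpt-4", "messages": [...], "temperature": 0.7}
        // (fields left as `None` are omitted; the exact message shape depends on how
        // `ChatMessage` serializes, which the API expects as `role`/`content` pairs).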

        // Send request
        let response = self
            .http_client
            .post(OPENAI_API_URL)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        // Check status
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(match status.as_u16() {
                401 => LlmError::AuthenticationFailed(error_text),
                429 => LlmError::RateLimitExceeded,
                _ => LlmError::ApiError(format!("Status {}: {}", status, error_text)),
            });
        }

        // Parse response
        let openai_response: OpenAIChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ParseError(e.to_string()))?;

        // Extract first choice
        let choice = openai_response
            .choices
            .first()
            .ok_or_else(|| LlmError::ParseError("No choices in response".to_string()))?;

        Ok(ChatResponse {
            content: choice.message.content.clone(),
            model: openai_response.model,
            usage: openai_response.usage.map(|u| super::super::types::Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
            finish_reason: choice.finish_reason.clone(),
        })
    }

    async fn chat_stream(&self, _request: ChatRequest) -> LlmResult<TextStream> {
        // Streaming is not implemented for OpenAI yet; a full implementation would
        // set `"stream": true` on the request and forward the SSE `data:` chunks.
        // For now, return an error suggesting the llama.cpp client for streaming.
        Err(LlmError::ApiError(
            "Streaming not yet implemented for OpenAI - use LlamaClient".to_string(),
        ))
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn provider(&self) -> &str {
        "openai"
    }
}

// OpenAI-specific request/response types

#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<super::super::types::ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: String,
    choices: Vec<Choice>,
    usage: Option<UsageInfo>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Message {
    content: String,
}

#[derive(Debug, Deserialize)]
struct UsageInfo {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
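
// A minimal sketch of a deserialization check for the response types above. It
// assumes `serde_json` is available to this crate (e.g. as a dev-dependency).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_minimal_chat_completion_response() {
        let body = r#"{
            "model": "gpt-4",
            "choices": [
                { "message": { "content": "Hello!" }, "finish_reason": "stop" }
            ],
            "usage": { "prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7 }
        }"#;

        let parsed: OpenAIChatResponse = serde_json::from_str(body).expect("valid response JSON");
        assert_eq!(parsed.model, "gpt-4");
        assert_eq!(parsed.choices[0].message.content, "Hello!");
        assert_eq!(parsed.usage.map(|u| u.total_tokens), Some(7));
    }
}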