llm_api_rs/providers/
openai.rs

// src/providers/openai.rs
// OpenAI provider for the Chat Completions endpoint.
// API reference: https://platform.openai.com/docs/api-reference/chat/create
// Platform home: https://platform.openai.com

5use crate::core::client::APIClient;
6use crate::core::{
7    ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ChatUsage,
8};
9use crate::error::LlmApiError;
10use async_trait::async_trait;
11use reqwest::header;
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Serialize)]
15pub struct OpenAIChatCompletionRequest {
16    pub model: String,
17    pub messages: Vec<ChatMessage>,
18    pub temperature: Option<f32>,
19    pub max_tokens: Option<u32>,
20}
21
22#[derive(Debug, Deserialize)]
23pub struct OpenAIChatCompletionResponse {
24    pub id: String,
25    pub model: String,
26    pub choices: Vec<ChatChoice>,
27    pub usage: Option<ChatUsage>,
28}
29
30pub struct OpenAI {
31    domain: String,
32    api_key: String,
33    client: APIClient,
34}
35
36impl OpenAI {
37    pub fn new(api_key: String) -> Self {
38        OpenAI {
39            domain: "https://api.openai.com".to_string(),
40            api_key,
41            client: APIClient::new(),
42        }
43    }
44}
45
46#[async_trait]
47impl crate::providers::LlmProvider for OpenAI {
48    async fn chat_completion(
49        &self,
50        request: ChatCompletionRequest,
51    ) -> Result<ChatCompletionResponse, LlmApiError> {
52        let url = format!("{}/v1/chat/completions", self.domain);
53
54        let headers = vec![(header::AUTHORIZATION, format!("Bearer {}", self.api_key))];
55
56        let req = OpenAIChatCompletionRequest {
57            model: request.model,
58            messages: request.messages,
59            temperature: request.temperature,
60            max_tokens: request.max_tokens,
61        };
62
63        let res: OpenAIChatCompletionResponse =
64            self.client.send_request(url, headers, &req).await?;
65
66        Ok(ChatCompletionResponse {
67            id: res.id,
68            choices: res.choices,
69            model: res.model,
70            usage: res.usage,
71        })
72    }
73}