potato_agent/agents/
client.rs

1use crate::agents::error::AgentError;
2use crate::agents::provider::gemini::GeminiClient;
3use crate::agents::provider::openai::OpenAIClient;
4use crate::agents::types::ChatResponse;
5use potato_prompt::Prompt;
6use potato_type::Provider;
7use pyo3::prelude::*;
8use reqwest::header::HeaderName;
9use reqwest::header::{HeaderMap, HeaderValue};
10use reqwest::Client;
11use std::collections::HashMap;
12use std::str::FromStr;
13use tracing::{error, instrument};
/// Timeout, in seconds, handed to the `reqwest` client builder in [`build_http_client`].
const TIMEOUT_SECS: u64 = 30;
15
/// Kind of client, exposed to Python as a class through `#[pyclass]`.
#[derive(Debug, Clone)]
#[pyclass]
pub enum ClientType {
    /// OpenAI; currently the only variant.
    OpenAI,
}
21
/// Selector for the hard-coded URL returned by [`ClientUrl::url`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientUrl {
    /// Maps to `https://api.openai.com`.
    OpenAI,
}

impl ClientUrl {
    /// Returns the URL associated with this variant.
    ///
    /// Every arm yields a string literal, so the result is `&'static str` and
    /// is not tied to the lifetime of the `ClientUrl` it was read from.
    pub fn url(&self) -> &'static str {
        match self {
            Self::OpenAI => "https://api.openai.com",
        }
    }
}
33
34/// Create the blocking HTTP client with optional headers.
35pub fn build_http_client(
36    client_headers: Option<HashMap<String, String>>,
37) -> Result<Client, AgentError> {
38    let mut headers = HeaderMap::new();
39
40    if let Some(headers_map) = client_headers {
41        for (key, value) in headers_map {
42            headers.insert(
43                HeaderName::from_str(&key).map_err(AgentError::CreateHeaderNameError)?,
44                HeaderValue::from_str(&value).map_err(AgentError::CreateHeaderValueError)?,
45            );
46        }
47    }
48
49    let client_builder = Client::builder().timeout(std::time::Duration::from_secs(TIMEOUT_SECS));
50
51    let client = client_builder
52        .default_headers(headers)
53        .build()
54        .map_err(AgentError::CreateClientError)?;
55
56    Ok(client)
57}
58
/// Provider client wrapper with one variant per wrapped client type.
#[derive(Debug, Clone, PartialEq)]
pub enum GenAiClient {
    /// Wraps an [`OpenAIClient`].
    OpenAI(OpenAIClient),
    /// Wraps a [`GeminiClient`].
    Gemini(GeminiClient),
}
64
65impl GenAiClient {
66    #[instrument(skip_all)]
67    pub async fn execute(&self, task: &Prompt) -> Result<ChatResponse, AgentError> {
68        match self {
69            GenAiClient::OpenAI(client) => {
70                let response = client.async_chat_completion(task).await.inspect_err(|e| {
71                    error!(error = %e, "Failed to complete chat");
72                })?;
73                Ok(ChatResponse::OpenAI(response))
74            }
75            GenAiClient::Gemini(client) => {
76                let response = client.async_generate_content(task).await.inspect_err(|e| {
77                    error!(error = %e, "Failed to generate content");
78                })?;
79                Ok(ChatResponse::Gemini(response))
80            }
81        }
82    }
83
84    pub fn provider(&self) -> &Provider {
85        match self {
86            GenAiClient::OpenAI(client) => &client.provider,
87            GenAiClient::Gemini(client) => &client.provider,
88        }
89    }
90}