potato_agent/agents/
client.rs

1use crate::agents::error::AgentError;
2use crate::agents::provider::openai::OpenAIClient;
3use crate::agents::provider::types::Provider;
4use crate::agents::types::ChatResponse;
5use potato_prompt::Prompt;
6use pyo3::prelude::*;
7use reqwest::header::HeaderName;
8use reqwest::header::{HeaderMap, HeaderValue};
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::str::FromStr;
/// Timeout, in seconds, configured on the HTTP client built by `build_http_client`.
const TIMEOUT_SECS: u64 = 30;
14
/// Kinds of client known to this module; also exported as a Python class via `#[pyclass]`.
#[derive(Debug, Clone)]
#[pyclass]
pub enum ClientType {
    /// The OpenAI client kind.
    OpenAI,
}
20
/// Base URL of each supported provider's HTTP API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientUrl {
    /// The OpenAI API host.
    OpenAI,
}

impl ClientUrl {
    /// Returns the base URL for this provider.
    ///
    /// The URL is a string literal, so the returned slice is `'static` and is
    /// not tied to the lifetime of the `ClientUrl` value it was read from.
    pub fn url(&self) -> &'static str {
        match self {
            Self::OpenAI => "https://api.openai.com",
        }
    }
}
32
33/// Create the blocking HTTP client with optional headers.
34pub fn build_http_client(
35    client_headers: Option<HashMap<String, String>>,
36) -> Result<Client, AgentError> {
37    let mut headers = HeaderMap::new();
38
39    if let Some(headers_map) = client_headers {
40        for (key, value) in headers_map {
41            headers.insert(
42                HeaderName::from_str(&key).map_err(AgentError::CreateHeaderNameError)?,
43                HeaderValue::from_str(&value).map_err(AgentError::CreateHeaderValueError)?,
44            );
45        }
46    }
47
48    let client_builder = Client::builder().timeout(std::time::Duration::from_secs(TIMEOUT_SECS));
49
50    let client = client_builder
51        .default_headers(headers)
52        .build()
53        .map_err(AgentError::CreateClientError)?;
54
55    Ok(client)
56}
57
/// A provider-specific client wrapped behind a single enum, with one variant per provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GenAiClient {
    /// Wraps an `OpenAIClient`.
    OpenAI(OpenAIClient),
}
62
63impl GenAiClient {
64    pub async fn execute(&self, task: &Prompt) -> Result<ChatResponse, AgentError> {
65        match self {
66            GenAiClient::OpenAI(client) => {
67                let response = client.async_chat_completion(task).await?;
68                Ok(ChatResponse::OpenAI(response))
69            }
70        }
71    }
72
73    pub fn provider(&self) -> &Provider {
74        match self {
75            GenAiClient::OpenAI(client) => &client.provider,
76        }
77    }
78}