potato_agent/agents/
client.rs1use crate::agents::error::AgentError;
2use crate::agents::provider::openai::OpenAIClient;
3use crate::agents::provider::types::Provider;
4use crate::agents::types::ChatResponse;
5use potato_prompt::Prompt;
6use pyo3::prelude::*;
7use reqwest::header::HeaderName;
8use reqwest::header::{HeaderMap, HeaderValue};
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::str::FromStr;
/// Request timeout, in seconds, applied to every client created by `build_http_client`.
const TIMEOUT_SECS: u64 = 30;
14
15#[derive(Debug, Clone)]
16#[pyclass]
17pub enum ClientType {
18 OpenAI,
19}
20
/// Base URLs of the supported provider APIs.
pub enum ClientUrl {
    /// OpenAI's public API host.
    OpenAI,
}

impl ClientUrl {
    /// Returns the base URL for this provider, without a trailing slash.
    ///
    /// Every URL is a string literal, so the returned slice is `'static`. It is not tied to
    /// the lifetime of `self`, and callers may keep it after the `ClientUrl` is dropped.
    pub fn url(&self) -> &'static str {
        match self {
            ClientUrl::OpenAI => "https://api.openai.com",
        }
    }
}
32
33pub fn build_http_client(
35 client_headers: Option<HashMap<String, String>>,
36) -> Result<Client, AgentError> {
37 let mut headers = HeaderMap::new();
38
39 if let Some(headers_map) = client_headers {
40 for (key, value) in headers_map {
41 headers.insert(
42 HeaderName::from_str(&key).map_err(AgentError::CreateHeaderNameError)?,
43 HeaderValue::from_str(&value).map_err(AgentError::CreateHeaderValueError)?,
44 );
45 }
46 }
47
48 let client_builder = Client::builder().timeout(std::time::Duration::from_secs(TIMEOUT_SECS));
49
50 let client = client_builder
51 .default_headers(headers)
52 .build()
53 .map_err(AgentError::CreateClientError)?;
54
55 Ok(client)
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum GenAiClient {
60 OpenAI(OpenAIClient),
61}
62
63impl GenAiClient {
64 pub async fn execute(&self, task: &Prompt) -> Result<ChatResponse, AgentError> {
65 match self {
66 GenAiClient::OpenAI(client) => {
67 let response = client.async_chat_completion(task).await?;
68 Ok(ChatResponse::OpenAI(response))
69 }
70 }
71 }
72
73 pub fn provider(&self) -> &Provider {
74 match self {
75 GenAiClient::OpenAI(client) => &client.provider,
76 }
77 }
78}