kalosm_language_model/openai/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
use std::sync::OnceLock;

use thiserror::Error;

mod embedding;
pub use embedding::*;

mod chat;
pub use chat::*;

/// A client for making requests to an OpenAI compatible API.
///
/// The [`Debug`](std::fmt::Debug) implementation redacts the API key so the client can be
/// logged without leaking credentials.
#[derive(Clone)]
pub struct OpenAICompatibleClient {
    /// HTTP client used to send requests.
    reqwest_client: reqwest::Client,
    /// Base URL of the API (defaults to `https://api.openai.com/v1/`).
    base_url: String,
    /// API key explicitly supplied through the builder, if any.
    api_key: Option<String>,
    /// Cache of the key produced by `resolve_api_key`.
    resolved_api_key: OnceLock<String>,
    /// Optional organization ID set through the builder.
    organization_id: Option<String>,
    /// Optional project ID set through the builder.
    project_id: Option<String>,
}

impl std::fmt::Debug for OpenAICompatibleClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print secrets: only report whether a key is present.
        let redact = |key: Option<&String>| key.map(|_| "<redacted>");
        f.debug_struct("OpenAICompatibleClient")
            .field("reqwest_client", &self.reqwest_client)
            .field("base_url", &self.base_url)
            .field("api_key", &redact(self.api_key.as_ref()))
            .field("resolved_api_key", &redact(self.resolved_api_key.get()))
            .field("organization_id", &self.organization_id)
            .field("project_id", &self.project_id)
            .finish()
    }
}

impl Default for OpenAICompatibleClient {
    fn default() -> Self {
        Self::new()
    }
}

impl OpenAICompatibleClient {
    /// Create a new client targeting `https://api.openai.com/v1/`.
    ///
    /// No API key is stored up front; unless [`Self::with_api_key`] is called, the key is read
    /// from the `OPENAI_API_KEY` environment variable by [`Self::resolve_api_key`].
    pub fn new() -> Self {
        Self {
            reqwest_client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1/".to_string(),
            resolved_api_key: OnceLock::new(),
            api_key: None,
            organization_id: None,
            project_id: None,
        }
    }

    /// Sets the API key for the builder. (defaults to the environment variable `OPENAI_API_KEY`)
    ///
    /// The API key can be accessed from the OpenAI dashboard [here](https://platform.openai.com/settings/organization/api-keys).
    pub fn with_api_key(mut self, api_key: impl ToString) -> Self {
        self.api_key = Some(api_key.to_string());
        // Discard any key cached by an earlier `resolve_api_key` call (e.g. one read from the
        // environment) so the explicitly provided key actually takes effect.
        self.resolved_api_key = OnceLock::new();
        self
    }

    /// Set the base URL of the API. (defaults to `https://api.openai.com/v1/`)
    pub fn with_base_url(mut self, base_url: impl ToString) -> Self {
        self.base_url = base_url.to_string();
        self
    }

    /// Set the organization ID for the builder.
    ///
    /// The organization ID can be accessed from the OpenAI dashboard [here](https://platform.openai.com/settings/organization/general).
    pub fn with_organization_id(mut self, organization_id: impl ToString) -> Self {
        self.organization_id = Some(organization_id.to_string());
        self
    }

    /// Set the project ID for the builder.
    ///
    /// The project ID can be accessed from the OpenAI dashboard [here](https://platform.openai.com/settings/organization/projects).
    pub fn with_project_id(mut self, project_id: impl ToString) -> Self {
        self.project_id = Some(project_id.to_string());
        self
    }

    /// Set the reqwest client for the builder.
    pub fn with_reqwest_client(mut self, client: reqwest::Client) -> Self {
        self.reqwest_client = client;
        self
    }

    /// Resolve the OpenAI API key.
    ///
    /// Uses the key given to [`Self::with_api_key`] if present, otherwise the
    /// `OPENAI_API_KEY` environment variable. The result is cached for later calls.
    ///
    /// # Errors
    ///
    /// Returns [`NoOpenAIAPIKeyError`] if no key was provided and the environment variable is
    /// unset or not valid Unicode.
    pub fn resolve_api_key(&self) -> Result<String, NoOpenAIAPIKeyError> {
        if let Some(api_key) = self.resolved_api_key.get() {
            return Ok(api_key.clone());
        }

        let api_key = match &self.api_key {
            Some(api_key) => api_key.clone(),
            None => std::env::var("OPENAI_API_KEY").map_err(|_| NoOpenAIAPIKeyError)?,
        };

        // Another thread may have filled the cache since the `get` above. `get_or_init` keeps
        // whichever value won, whereas `set(..).unwrap()` would panic in that race.
        Ok(self.resolved_api_key.get_or_init(|| api_key).clone())
    }

    /// Get the base URL for the OpenAI API with any trailing `/` characters removed.
    pub(crate) fn base_url(&self) -> &str {
        self.base_url.trim_end_matches('/')
    }
}

/// An error that can occur when building a remote OpenAI model without an API key.
#[derive(Debug, Error)]
#[error("No API key was provided in the [OpenAICompatibleClient] builder or the environment variable `OPENAI_API_KEY` was not set")]
pub struct NoOpenAIAPIKeyError;