Skip to main content

rig/providers/together/
client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, Nothing, Provider, ProviderBuilder, ProviderClient,
4    },
5    http_client,
6};
7
8// ================================================================
9// Together AI Client
10// ================================================================
11const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
12
13#[derive(Debug, Default, Clone, Copy)]
14pub struct TogetherExt;
15#[derive(Debug, Default, Clone, Copy)]
16pub struct TogetherExtBuilder;
17
18type TogetherApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<TogetherExt, H>;
21pub type ClientBuilder<H = reqwest::Client> =
22    client::ClientBuilder<TogetherExtBuilder, TogetherApiKey, H>;
23
24impl Provider for TogetherExt {
25    type Builder = TogetherExtBuilder;
26
27    const VERIFY_PATH: &'static str = "/models";
28}
29
30impl<H> Capabilities<H> for TogetherExt {
31    type Completion = Capable<super::CompletionModel<H>>;
32    type Embeddings = Capable<super::EmbeddingModel<H>>;
33
34    type Transcription = Nothing;
35    type ModelListing = Nothing;
36    #[cfg(feature = "image")]
37    type ImageGeneration = Nothing;
38    #[cfg(feature = "audio")]
39    type AudioGeneration = Nothing;
40}
41
42impl ProviderBuilder for TogetherExtBuilder {
43    type Extension<H>
44        = TogetherExt
45    where
46        H: http_client::HttpClientExt;
47    type ApiKey = TogetherApiKey;
48
49    const BASE_URL: &'static str = TOGETHER_AI_BASE_URL;
50
51    fn build<H>(
52        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
53    ) -> http_client::Result<Self::Extension<H>>
54    where
55        H: http_client::HttpClientExt,
56    {
57        Ok(TogetherExt)
58    }
59}
60
61impl ProviderClient for Client {
62    type Input = String;
63
64    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
65    /// Panics if the environment variable is not set.
66    fn from_env() -> Self {
67        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
68        Self::new(&api_key).unwrap()
69    }
70
71    fn from_val(input: Self::Input) -> Self {
72        Self::new(&input).unwrap()
73    }
74}
75
76pub mod together_ai_api_types {
77    use serde::Deserialize;
78
79    impl ApiErrorResponse {
80        pub fn message(&self) -> String {
81            format!("Code `{}`: {}", self.code, self.error)
82        }
83    }
84
85    #[derive(Debug, Deserialize)]
86    pub struct ApiErrorResponse {
87        pub error: String,
88        pub code: String,
89    }
90
91    #[derive(Debug, Deserialize)]
92    #[serde(untagged)]
93    pub enum ApiResponse<T> {
94        Ok(T),
95        Error(ApiErrorResponse),
96    }
97}
#[cfg(test)]
mod tests {
    use crate::providers::together::Client;

    /// Both construction paths (direct `new` and the builder) accept a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}