rig_core/providers/together/
client.rs1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, Nothing, Provider, ProviderBuilder, ProviderClient,
4 },
5 http_client,
6};
7
/// Root URL of the Together AI HTTP API.
///
/// Note that it carries no API version segment (such as `/v1`).
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
12
/// Zero-sized marker selecting the Together AI provider in the generic `client::Client`.
#[derive(Clone, Copy, Debug, Default)]
pub struct TogetherExt;

/// Zero-sized builder-side marker for Together AI, used as the provider parameter of
/// `client::ClientBuilder`.
#[derive(Clone, Copy, Debug, Default)]
pub struct TogetherExtBuilder;
17
18type TogetherApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<TogetherExt, H>;
21pub type ClientBuilder<H = crate::markers::Missing> =
22 client::ClientBuilder<TogetherExtBuilder, TogetherApiKey, H>;
23
24impl Provider for TogetherExt {
25 type Builder = TogetherExtBuilder;
26
27 const VERIFY_PATH: &'static str = "/models";
28}
29
30impl<H> Capabilities<H> for TogetherExt {
31 type Completion = Capable<super::CompletionModel<H>>;
32 type Embeddings = Capable<super::EmbeddingModel<H>>;
33
34 type Transcription = Nothing;
35 type ModelListing = Nothing;
36 #[cfg(feature = "image")]
37 type ImageGeneration = Nothing;
38 #[cfg(feature = "audio")]
39 type AudioGeneration = Nothing;
40 type Rerank = Nothing;
41}
42
43impl ProviderBuilder for TogetherExtBuilder {
44 type Extension<H>
45 = TogetherExt
46 where
47 H: http_client::HttpClientExt;
48 type ApiKey = TogetherApiKey;
49
50 const BASE_URL: &'static str = TOGETHER_AI_BASE_URL;
51
52 fn build<H>(
53 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
54 ) -> http_client::Result<Self::Extension<H>>
55 where
56 H: http_client::HttpClientExt,
57 {
58 Ok(TogetherExt)
59 }
60}
61
62impl ProviderClient for Client {
63 type Input = String;
64 type Error = crate::client::ProviderClientError;
65
66 fn from_env() -> Result<Self, Self::Error> {
68 let api_key = crate::client::required_env_var("TOGETHER_API_KEY")?;
69 Self::new(&api_key).map_err(Into::into)
70 }
71
72 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
73 Self::new(&input).map_err(Into::into)
74 }
75}
76
77pub mod together_ai_api_types {
78 use serde::Deserialize;
79
80 impl ApiErrorResponse {
81 pub fn message(&self) -> String {
82 format!("Code `{}`: {}", self.code, self.error)
83 }
84 }
85
86 #[derive(Debug, Deserialize)]
87 pub struct ApiErrorResponse {
88 pub error: String,
89 pub code: String,
90 }
91
92 #[derive(Debug, Deserialize)]
93 #[serde(untagged)]
94 pub enum ApiResponse<T> {
95 Ok(T),
96 Error(ApiErrorResponse),
97 }
98}
#[cfg(test)]
mod tests {
    use crate::providers::together::Client;

    /// Checks that both construction paths succeed with a placeholder key: the direct
    /// `Client::new` and the `Client::builder()` chain.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}