1use super::completion::{CompletionRequest, CompletionResponse};
2use super::embedding::{
3 BatchSemanticEmbeddingRequest, BatchSemanticEmbeddingResponse, EmbeddingRequest,
4 EmbeddingResponse, SemanticEmbeddingRequest, SemanticEmbeddingResponse,
5};
6use super::error::ApiError;
7use super::evaluate::{EvaluationRequest, EvaluationResponse};
8use super::explanation::{ExplanationRequest, ExplanationResponse};
9use super::http;
10use super::tokenization::{
11 DetokenizationRequest, DetokenizationResponse, TokenizationRequest, TokenizationResponse,
12};
13use bytes::Bytes;
14use tokenizers::Tokenizer;
15
16pub struct Client {
17 http_client: reqwest::Client,
18 pub base_url: String,
19 pub api_token: String,
20}
21
/// Base URL of the public Aleph Alpha API, used by `Client::new`.
///
/// Deliberately has no trailing slash: request paths such as `/complete`
/// already start with `/` and are appended to this prefix verbatim.
pub const ALEPH_ALPHA_API_BASE_URL: &str = "https://api.aleph-alpha.com";
23
24impl Client {
25 pub fn new(api_token: String) -> Result<Self, ApiError> {
27 Self::new_with_base_url(ALEPH_ALPHA_API_BASE_URL.to_owned(), api_token)
28 }
29
30 pub fn new_with_base_url(base_url: String, api_token: String) -> Result<Self, ApiError> {
33 Ok(Self {
34 http_client: http::create_client(&api_token)?,
35 base_url,
36 api_token,
37 })
38 }
39
40 pub async fn post<I: serde::ser::Serialize, O: serde::de::DeserializeOwned>(
41 &self,
42 path: &str,
43 data: &I,
44 query: Option<Vec<(String, String)>>,
45 ) -> Result<O, ApiError> {
46 use reqwest::header::{ACCEPT, CONTENT_TYPE};
47
48 let url = format!("{base_url}{path}", base_url = self.base_url, path = path);
49 let mut request = self.http_client.post(url);
50
51 if let Some(q) = query {
52 request = request.query(&q);
53 }
54
55 let request = request
56 .header(CONTENT_TYPE, "application/json")
57 .header(ACCEPT, "application/json")
58 .json(data);
59
60 let response = request.send().await?;
61 let response = http::translate_http_error(response).await?;
62 let response_body: O = response.json().await?;
63 Ok(response_body)
64 }
65
66 pub async fn post_nice<I: serde::ser::Serialize, O: serde::de::DeserializeOwned>(
67 &self,
68 path: &str,
69 data: &I,
70 nice: Option<bool>,
71 ) -> Result<O, ApiError> {
72 let query = if let Some(be_nice) = nice {
73 Some(vec![("nice".to_owned(), be_nice.to_string())])
74 } else {
75 None
76 };
77 Ok(self.post(path, data, query).await?)
78 }
79
80 pub async fn get<O: serde::de::DeserializeOwned>(&self, path: &str) -> Result<O, ApiError> {
81 let response = http::get(&self.http_client, &self.base_url, path, None).await?;
82 let response_body = response.json().await?;
83 Ok(response_body)
84 }
85
86 pub async fn get_string(&self, path: &str) -> Result<String, ApiError> {
87 let response = http::get(&self.http_client, &self.base_url, path, None).await?;
88 let response_body = response.text().await?;
89 Ok(response_body)
90 }
91
92 pub async fn get_binary(&self, path: &str) -> Result<Bytes, ApiError> {
93 let response = http::get(&self.http_client, &self.base_url, path, None).await?;
94 let response_body = response.bytes().await?;
95 Ok(response_body)
96 }
97
98 pub async fn completion(
124 &self,
125 req: &CompletionRequest,
126 nice: Option<bool>,
127 ) -> Result<CompletionResponse, ApiError> {
128 Ok(self.post_nice("/complete", req, nice).await?)
129 }
130
131 pub async fn evaluate(
133 &self,
134 req: &EvaluationRequest,
135 nice: Option<bool>,
136 ) -> Result<EvaluationResponse, ApiError> {
137 Ok(self.post_nice("/evaluate", req, nice).await?)
138 }
139
140 pub async fn explain(
142 &self,
143 req: &ExplanationRequest,
144 nice: Option<bool>,
145 ) -> Result<ExplanationResponse, ApiError> {
146 Ok(self.post_nice("/explain", req, nice).await?)
147 }
148
149 pub async fn embed(
151 &self,
152 req: &EmbeddingRequest,
153 nice: Option<bool>,
154 ) -> Result<EmbeddingResponse, ApiError> {
155 Ok(self.post_nice("/embed", req, nice).await?)
156 }
157
158 pub async fn semantic_embed(
160 &self,
161 req: &SemanticEmbeddingRequest,
162 nice: Option<bool>,
163 ) -> Result<SemanticEmbeddingResponse, ApiError> {
164 Ok(self.post_nice("/semantic_embed", req, nice).await?)
165 }
166
167 pub async fn batch_semantic_embed(
169 &self,
170 req: &BatchSemanticEmbeddingRequest,
171 nice: Option<bool>,
172 ) -> Result<BatchSemanticEmbeddingResponse, ApiError> {
173 Ok(self.post_nice("/batch_semantic_embed", req, nice).await?)
174 }
175
176 pub async fn tokenize(
178 &self,
179 req: &TokenizationRequest,
180 ) -> Result<TokenizationResponse, ApiError> {
181 Ok(self.post("/tokenize", req, None).await?)
182 }
183
184 pub async fn detokenize(
186 &self,
187 req: &DetokenizationRequest,
188 ) -> Result<DetokenizationResponse, ApiError> {
189 Ok(self.post("/detokenize", req, None).await?)
190 }
191
192 pub async fn get_tokenizer_binary(&self, model: &str) -> Result<Bytes, ApiError> {
193 let path = format!("/models/{model}/tokenizer");
194 let vocabulary = self.get_binary(&path).await?;
195 Ok(vocabulary)
196 }
197
198 pub async fn get_tokenizer(&self, model: &str) -> Result<Tokenizer, ApiError> {
199 let vocabulary = self.get_tokenizer_binary(model).await?;
200 let tokenizer = Tokenizer::from_bytes(vocabulary)?;
201 Ok(tokenizer)
202 }
203
204 pub async fn get_version(&self) -> Result<String, ApiError> {
206 Ok(self.get_string("/version").await?)
207 }
208}