use crate::protocol as p;
use anyhow::Result;
use tokio_tungstenite::tungstenite as ws;
use tokio_tungstenite::tungstenite::Utf8Bytes;
/// Client-side WebSocket connection over a TCP socket that may or may not be wrapped in TLS.
pub type WebSocket = tokio_tungstenite::WebSocketStream<
    tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>;
/// Write half of a [`WebSocket`] once it has been split into sink and stream halves.
pub type WebSocketSender =
    futures_util::stream::SplitSink<WebSocket, tokio_tungstenite::tungstenite::Message>;
/// Read half of a [`WebSocket`] once it has been split into sink and stream halves.
pub type WebSocketReceiver = futures_util::stream::SplitStream<WebSocket>;
/// Value sent in the `x-api-source` header when no API source has been configured.
pub const DEFAULT_API_SOURCE: &str = "rust-client";
/// Connection settings for the Gradium API.
///
/// Holds the credentials, the server to talk to, and any extra request headers.
/// Create one with a constructor such as `Client::new` and adjust it with the
/// `with_*` builder methods.
///
/// NOTE(review): there is no `Debug` derive — possibly deliberate so the API key is
/// never printed; confirm before adding one.
#[derive(Clone)]
pub struct Client {
    /// Secret sent in the `x-api-key` header.
    api_key: String,
    /// Value for the `x-api-source` header; `None` means the default source string is sent.
    api_source: Option<String>,
    /// Host, optionally followed by `:port`, placed right after the URL scheme.
    server_addr: String,
    /// `true` selects `https`/`wss` URLs, `false` selects `http`/`ws`.
    use_https: bool,
    /// Path prefix inserted between the server address and the endpoint (may be empty).
    path: String,
    /// Extra `(name, value)` header pairs applied after the built-in headers.
    additional_headers: Vec<(String, String)>,
}
/// A Gradium API deployment that a client can target.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Location {
    /// The general endpoint, `api.gradium.ai`.
    Default,
    /// The EU endpoint, `eu.api.gradium.ai`.
    EU,
    /// The US endpoint, `us.api.gradium.ai`.
    US,
}
impl Location {
    /// Single lookup table: `(identifier, server host)` for this location.
    ///
    /// Both public accessors read from here so the two mappings cannot drift apart.
    fn identifier_and_host(&self) -> (&'static str, &'static str) {
        match *self {
            Self::Default => ("default", "api.gradium.ai"),
            Self::EU => ("eu", "eu.api.gradium.ai"),
            Self::US => ("us", "us.api.gradium.ai"),
        }
    }

    /// Short lowercase identifier: `"default"`, `"eu"`, or `"us"`.
    pub fn as_str(&self) -> &'static str {
        let (identifier, _) = self.identifier_and_host();
        identifier
    }

    /// Host name of this location's API server, without scheme or port.
    pub fn server_addr(&self) -> &'static str {
        let (_, host) = self.identifier_and_host();
        host
    }
}
impl Client {
    /// Creates a client for the server of `location`, using HTTPS and the `api` path prefix.
    pub fn from_location(api_key: &str, location: Location) -> Self {
        Client {
            api_key: api_key.to_string(),
            server_addr: location.server_addr().to_string(),
            use_https: true,
            path: "api".to_string(),
            additional_headers: Vec::new(),
            api_source: None,
        }
    }

    /// Creates a client for [`Location::Default`].
    pub fn new(api_key: &str) -> Self {
        Self::from_location(api_key, Location::Default)
    }

    /// Creates a client for the US server ([`Location::US`]).
    pub fn us_prod(api_key: &str) -> Self {
        // Delegate so the host names are defined only once, in `Location::server_addr`.
        Self::from_location(api_key, Location::US)
    }

    /// Creates a client for the EU server ([`Location::EU`]).
    pub fn eu_prod(api_key: &str) -> Self {
        Self::from_location(api_key, Location::EU)
    }

    /// Adds a header to every WebSocket handshake and HTTP request made by this client.
    ///
    /// Headers are applied in order after `x-api-key` and `x-api-source`. Each one
    /// replaces any earlier header with the same (case-insensitive) name.
    pub fn with_additional_header(mut self, key: &str, value: &str) -> Self {
        self.additional_headers.push((key.to_string(), value.to_string()));
        self
    }

    /// Sets the `x-api-source` header value (otherwise [`DEFAULT_API_SOURCE`] is sent).
    pub fn with_api_source(mut self, api_source: String) -> Self {
        self.api_source = Some(api_source);
        self
    }

    /// Builds a client, falling back to the environment for missing arguments.
    ///
    /// A missing `api_key` is taken from `crate::api_key_from_env()`. A missing
    /// `base_url` is taken from `crate::base_url_from_env()`. If neither yields a base
    /// URL, the client targets [`Location::Default`].
    ///
    /// # Errors
    /// Fails when no API key is available, or when [`Client::with_base_url`] rejects
    /// the base URL.
    pub fn from_env(base_url: Option<String>, api_key: Option<String>) -> Result<Self> {
        let api_key = api_key
            .or_else(crate::api_key_from_env)
            .ok_or_else(|| anyhow::anyhow!("API key not provided and GRADIUM_API_KEY not set"))?;
        let client = Client::new(&api_key);
        let client = match base_url {
            Some(base_url) => client.with_base_url(&base_url)?,
            None => match crate::base_url_from_env() {
                Some(base_url) => client.with_base_url(&base_url)?,
                None => client,
            },
        };
        Ok(client)
    }

    /// Replaces the API key sent in the `x-api-key` header.
    pub fn with_api_key(mut self, api_key: &str) -> Self {
        self.api_key = api_key.to_string();
        self
    }

    /// Sets the server address (`host` or `host:port`) used in every URL.
    pub fn with_server_addr(mut self, server_addr: &str) -> Self {
        self.server_addr = server_addr.to_string();
        self
    }

    /// Chooses TLS URLs (`https`/`wss`) when `true`, plain ones (`http`/`ws`) when `false`.
    pub fn with_https(mut self, use_https: bool) -> Self {
        self.use_https = use_https;
        self
    }

    /// Sets the path prefix placed between the server address and the endpoint.
    ///
    /// An empty path puts endpoints directly under the server root.
    pub fn with_path(mut self, path: &str) -> Self {
        self.path = path.to_string();
        self
    }

    /// Points the client at `base_url`, e.g. `https://example.com:8443/api`.
    ///
    /// The host and any non-default port become the server address. An `https` or `wss`
    /// scheme selects TLS. The URL path, with leading and trailing slashes removed,
    /// becomes the path prefix. Query and fragment are ignored.
    ///
    /// # Errors
    /// Fails if `base_url` does not parse or has no host.
    pub fn with_base_url(mut self, base_url: &str) -> Result<Self> {
        let url = url::Url::parse(base_url)?;
        // A hostless URL is malformed for this client; failing here is safer than
        // silently sending the API key to some fallback server.
        let host = url
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("base url has no host: {base_url}"))?;
        self.server_addr = match url.port() {
            Some(port) => format!("{host}:{port}"),
            None => host.to_string(),
        };
        // Accept WebSocket schemes too, since the client speaks both protocols.
        self.use_https = matches!(url.scheme(), "https" | "wss");
        // Trim both ends so `https://host/api/` does not yield `.../api//endpoint`.
        self.path = url.path().trim_matches('/').to_string();
        Ok(self)
    }

    /// Shared URL builder that produces `{scheme}://{server_addr}[/{path}]/{endpoint}`.
    fn build_url(&self, scheme: &str, endpoint: &str) -> String {
        if self.path.is_empty() {
            format!("{scheme}://{}/{endpoint}", self.server_addr)
        } else {
            format!("{scheme}://{}/{}/{endpoint}", self.server_addr, self.path)
        }
    }

    /// WebSocket URL for `endpoint` (`wss://` when HTTPS is enabled, `ws://` otherwise).
    pub fn ws_url(&self, endpoint: &str) -> String {
        let scheme = if self.use_https { "wss" } else { "ws" };
        self.build_url(scheme, endpoint)
    }

    /// HTTP URL for `endpoint` (`https://` when HTTPS is enabled, `http://` otherwise).
    pub fn http_url(&self, endpoint: &str) -> String {
        let scheme = if self.use_https { "https" } else { "http" };
        self.build_url(scheme, endpoint)
    }

    /// Headers attached to every request, in the order they are applied.
    ///
    /// The order is `x-api-key`, then `x-api-source`, then the additional headers.
    fn header_pairs(&self) -> Vec<(&str, &str)> {
        let api_source = self.api_source.as_deref().unwrap_or(DEFAULT_API_SOURCE);
        let mut pairs = Vec::with_capacity(2 + self.additional_headers.len());
        pairs.push(("x-api-key", self.api_key.as_str()));
        pairs.push(("x-api-source", api_source));
        pairs.extend(
            self.additional_headers
                .iter()
                .map(|(key, value)| (key.as_str(), value.as_str())),
        );
        pairs
    }

    /// Opens a WebSocket connection to `endpoint` with this client's headers.
    ///
    /// Nagle's algorithm is disabled on the underlying TCP socket.
    ///
    /// # Errors
    /// Fails on an invalid URL or header, or if the connection or handshake fails.
    pub async fn ws_connect(&self, endpoint: &str) -> Result<WebSocket> {
        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
        use tokio_tungstenite::tungstenite::http::{HeaderName, HeaderValue};
        let mut request = self.ws_url(endpoint).into_client_request()?;
        let headers = request.headers_mut();
        for (key, value) in self.header_pairs() {
            headers.insert(HeaderName::from_bytes(key.as_bytes())?, HeaderValue::from_str(value)?);
        }
        let (stream, _response) =
            tokio_tungstenite::connect_async_with_config(request, None, true).await?;
        Ok(stream)
    }

    /// Synthesizes `text` in one call; thin wrapper over `crate::tts::tts`.
    pub async fn tts(&self, text: &str, setup: p::tts::Setup) -> Result<crate::tts::TtsResult> {
        crate::tts::tts(text, setup, self).await
    }

    /// Opens a streaming TTS session; thin wrapper over `crate::tts::tts_stream`.
    pub async fn tts_stream(&self, setup: p::tts::Setup) -> Result<crate::tts::TtsStream> {
        crate::tts::tts_stream(setup, self).await
    }

    /// Opens a multiplexed TTS connection via `crate::tts::TtsMultiplexStream::connect`.
    pub async fn tts_multiplex(&self) -> Result<crate::tts::TtsMultiplexStream> {
        crate::tts::TtsMultiplexStream::connect(self).await
    }

    /// Transcribes `audio` in one call; thin wrapper over `crate::stt::stt`.
    pub async fn stt(&self, audio: Vec<u8>, setup: p::stt::Setup) -> Result<crate::stt::SttResult> {
        crate::stt::stt(audio, setup, self).await
    }

    /// Opens a streaming STT session; thin wrapper over `crate::stt::stt_stream`.
    pub async fn stt_stream(&self, setup: p::stt::Setup) -> Result<crate::stt::SttStream> {
        crate::stt::stt_stream(setup, self).await
    }

    /// Sends a GET to `endpoint` with this client's headers and decodes the JSON body.
    ///
    /// # Errors
    /// Fails on an invalid header, a transport error, a 4xx/5xx status, or a body that
    /// is not valid JSON.
    pub(crate) async fn get(&self, endpoint: &str) -> Result<serde_json::Value> {
        use reqwest::header::{HeaderName, HeaderValue};
        let http = reqwest::Client::new();
        let mut request = http.get(self.http_url(endpoint)).build()?;
        // Apply headers exactly like `ws_connect` so both transports send the same set,
        // including the user's additional headers.
        let headers = request.headers_mut();
        for (key, value) in self.header_pairs() {
            headers.insert(HeaderName::from_bytes(key.as_bytes())?, HeaderValue::from_str(value)?);
        }
        let response = http.execute(request).await?.error_for_status()?;
        Ok(response.json().await?)
    }

    /// Fetches the account's credit information from `usages/credits`.
    pub async fn credits(&self) -> Result<crate::protocol::CreditsResponse> {
        let v = self.get("usages/credits").await?;
        let credits: crate::protocol::CreditsResponse = serde_json::from_value(v)?;
        Ok(credits)
    }

    /// Fetches the account's usage summary from `usages/summary`.
    pub async fn usage(&self) -> Result<crate::protocol::UsageResponse> {
        let v = self.get("usages/summary").await?;
        let usage: crate::protocol::UsageResponse = serde_json::from_value(v)?;
        Ok(usage)
    }
}
/// Reads the next text message from `ws`, answering Pings along the way.
///
/// Returns `Ok(None)` once the stream ends or the peer sends a Close frame. Pong and
/// raw frames are skipped.
///
/// # Errors
/// Fails on a transport or protocol error, on an unexpected binary message, or if the
/// Pong reply to a Ping cannot be sent.
pub(crate) async fn next_message(ws: &mut WebSocket) -> Result<Option<Utf8Bytes>> {
    use futures_util::SinkExt;
    use futures_util::StreamExt;
    use tokio_tungstenite::tungstenite::Message;
    loop {
        match ws.next().await {
            None | Some(Ok(Message::Close(_))) => return Ok(None),
            Some(Err(e)) => anyhow::bail!("websocket error: {}", e),
            Some(Ok(Message::Binary(_))) => anyhow::bail!("unexpected binary message"),
            Some(Ok(Message::Text(text))) => return Ok(Some(text)),
            // RFC 6455 §5.5.3: a Pong answering a Ping must carry the same application data.
            // tungstenite keeps only the most recently queued pong, so a user-sent Pong
            // replaces its automatic reply. Echo the payload rather than sending it empty.
            Some(Ok(Message::Ping(payload))) => ws.send(Message::Pong(payload)).await?,
            Some(Ok(Message::Pong(_) | Message::Frame(_))) => {}
        }
    }
}
pub(crate) async fn next_message_receiver(ws: &mut WebSocketReceiver) -> Result<Option<Utf8Bytes>> {
use futures_util::StreamExt;
use tokio_tungstenite::tungstenite::Message;
let msg = loop {
let msg = ws.next().await;
match msg {
None => return Ok(None),
Some(Err(e)) => anyhow::bail!("websocket error: {}", e),
Some(Ok(Message::Binary(_))) => anyhow::bail!("unexpected binary message"),
Some(Ok(Message::Close(_close_frame))) => {
return Ok(None);
}
Some(Ok(Message::Text(text))) => break text,
Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_) | Message::Frame(_))) => {}
};
};
Ok(Some(msg))
}