mod errors;
mod requests;
mod responses;
use std::{io::Write, path::Path};
pub use errors::AllTalkError;
pub use requests::*;
use reqwest::{Client, ClientBuilder};
pub use responses::*;
use url::Url;
/// Async HTTP client for an AllTalk TTS server.
///
/// Bundles the `reqwest` client used for every request with the base URL that
/// endpoint paths are resolved against.
#[derive(Debug, Clone)]
pub struct AllTalkClient {
    /// HTTP client used to send every request to the server.
    pub client: Client,
    /// Server base URL; endpoint paths such as `api/ready` are joined onto it.
    pub(crate) baseurl: Url,
}
impl AllTalkClient {
    /// Creates a client that talks to the AllTalk server at `baseurl`.
    ///
    /// The underlying HTTP client sends the `alltalk-client/0.1.0` user agent,
    /// keeps up to 20 idle connections per host, and applies a 10-minute
    /// timeout to every request.
    ///
    /// Endpoint paths are resolved with [`Url::join`], so if `baseurl` has a
    /// non-empty path it should end with `/`; otherwise its last segment is
    /// replaced by the endpoint path.
    ///
    /// # Errors
    ///
    /// Returns [`AllTalkError::ReqwestError`] if the HTTP client cannot be built.
    pub fn from_url(baseurl: Url) -> Result<Self, AllTalkError> {
        const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60);
        let client = ClientBuilder::new()
            .user_agent("alltalk-client/0.1.0")
            .pool_max_idle_per_host(20)
            .timeout(REQUEST_TIMEOUT)
            .build()
            .map_err(AllTalkError::ReqwestError)?;
        Ok(Self { client, baseurl })
    }

    /// Creates a client for the server named by the `ALLTALK_URL` environment variable.
    ///
    /// # Errors
    ///
    /// Returns an error if `ALLTALK_URL` does not parse as a URL, or if the
    /// HTTP client cannot be built.
    ///
    /// # Panics
    ///
    /// Panics if `ALLTALK_URL` is not set (or is not valid Unicode).
    pub fn from_environment() -> Result<Self, AllTalkError> {
        // NOTE(review): a missing variable still panics; consider a dedicated
        // error variant so this can be reported through the Result as well.
        let baseurl = std::env::var("ALLTALK_URL").expect("ALLTALK_URL is not set");
        // A malformed URL is a recoverable configuration error, so propagate it
        // instead of panicking.
        let baseurl = Url::parse(&baseurl)?;
        Self::from_url(baseurl)
    }

    /// Returns the body of `response` when its status is a success, otherwise an
    /// [`AllTalkError::ResponseError`] carrying the body text.
    async fn success_body(response: reqwest::Response) -> Result<String, AllTalkError> {
        let succeeded = response.status().is_success();
        let body = response.text().await?;
        if succeeded {
            Ok(body)
        } else {
            Err(AllTalkError::ResponseError(body))
        }
    }

    /// Reports whether `GET api/ready` succeeds with the exact body `Ready`.
    ///
    /// A non-success status yields `Ok(false)` rather than an error.
    ///
    /// # Errors
    ///
    /// Returns an error if the URL cannot be built, the request fails, or the
    /// body cannot be read.
    pub async fn is_ready(&self) -> Result<bool, AllTalkError> {
        let url = self.baseurl.join("api/ready")?;
        let response = self.client.get(url).send().await?;
        if !response.status().is_success() {
            return Ok(false);
        }
        Ok(response.text().await? == "Ready")
    }

    /// Lists the voices reported by `GET api/voices`.
    ///
    /// # Errors
    ///
    /// [`AllTalkError::ResponseError`] on a non-success status,
    /// [`AllTalkError::JsonParserError`] if the body is not a `VoicesResponse`,
    /// and transport errors from the request itself.
    pub async fn get_voices(&self) -> Result<Vec<String>, AllTalkError> {
        let url = self.baseurl.join("api/voices")?;
        let response = self.client.get(url).send().await?;
        let response_text = Self::success_body(response).await?;
        let voices: VoicesResponse = serde_json::from_str(&response_text).map_err(|_e| {
            AllTalkError::JsonParserError {
                json: response_text,
                target: String::from("VoicesResponse"),
            }
        })?;
        Ok(voices.voices)
    }

    /// Fetches the server's current settings from `GET api/currentsettings`.
    ///
    /// # Errors
    ///
    /// [`AllTalkError::ResponseError`] on a non-success status,
    /// [`AllTalkError::JsonParserError`] if the body is not a
    /// `CurrentSettingsResponse`, and transport errors from the request itself.
    pub async fn get_current_settings(&self) -> Result<CurrentSettingsResponse, AllTalkError> {
        let url = self.baseurl.join("api/currentsettings")?;
        let response = self.client.get(url).send().await?;
        let response_text = Self::success_body(response).await?;
        let current_settings: CurrentSettingsResponse = serde_json::from_str(&response_text)
            .map_err(|_e| AllTalkError::JsonParserError {
                json: response_text,
                target: String::from("CurrentSettingsResponse"),
            })?;
        Ok(current_settings)
    }

    /// Requests a preview of `voice` (form field `voice`, `POST api/previewvoice/`)
    /// and returns the `output_file_url` string from the JSON reply.
    ///
    /// # Errors
    ///
    /// [`AllTalkError::ResponseError`] on a non-success status or when the reply
    /// has no string `output_file_url`; [`AllTalkError::JsonParserError`] if the
    /// body is not JSON; transport errors from the request itself.
    pub async fn preview_voice(&self, voice: &str) -> Result<String, AllTalkError> {
        let url = self.baseurl.join("api/previewvoice/")?;
        let response = self.client.post(url).form(&[("voice", voice)]).send().await?;
        let response_text = Self::success_body(response).await?;
        let response_json: serde_json::Value = serde_json::from_str(&response_text).map_err(|_e| {
            AllTalkError::JsonParserError {
                json: response_text,
                target: String::from("serde_json::Value"),
            }
        })?;
        // Build the error lazily: the message is only needed on the failure path.
        let output_file_url = response_json["output_file_url"]
            .as_str()
            .ok_or_else(|| AllTalkError::ResponseError("output_file_url not found".to_string()))?;
        Ok(output_file_url.to_string())
    }

    /// Asks the server to reload using `tts_method` (`POST api/reload?tts_method=…`)
    /// and returns the `status` field of the reply.
    ///
    /// # Errors
    ///
    /// [`AllTalkError::ResponseError`] on a non-success status,
    /// [`AllTalkError::JsonParserError`] if the body is not a `StatusResponse`,
    /// and transport errors from the request itself.
    pub async fn switch_model(&self, tts_method: &str) -> Result<String, AllTalkError> {
        let url = self.baseurl.join("api/reload")?;
        let response = self
            .client
            .post(url)
            .query(&[("tts_method", tts_method)])
            .send()
            .await?;
        let response_text = Self::success_body(response).await?;
        let status_response: StatusResponse = serde_json::from_str(&response_text).map_err(|_e| {
            AllTalkError::JsonParserError {
                json: response_text,
                target: String::from("StatusResponse"),
            }
        })?;
        Ok(status_response.status)
    }

    /// Sets the DeepSpeed flag (`POST api/deepspeed?new_deepspeed_value=true|false`)
    /// and returns the `status` field of the reply.
    ///
    /// # Errors
    ///
    /// [`AllTalkError::ResponseError`] on a non-success status,
    /// [`AllTalkError::JsonParserError`] if the body is not a `StatusResponse`,
    /// and transport errors from the request itself.
    pub async fn switch_deepspeed(&self, new_deepspeed_value: bool) -> Result<String, AllTalkError> {
        let url = self.baseurl.join("api/deepspeed")?;
        let response = self
            .client
            .post(url)
            .query(&[("new_deepspeed_value", new_deepspeed_value.to_string())])
            .send()
            .await?;
        let response_text = Self::success_body(response).await?;
        let status_response: StatusResponse = serde_json::from_str(&response_text).map_err(|_e| {
            AllTalkError::JsonParserError {
                json: response_text,
                target: String::from("StatusResponse"),
            }
        })?;
        Ok(status_response.status)
    }

    /// Sets the low-VRAM flag (`POST api/lowvramsetting?new_low_vram_value=true|false`)
    /// and returns the `status` field of the reply.
    ///
    /// # Errors
    ///
    /// [`AllTalkError::ResponseError`] on a non-success status,
    /// [`AllTalkError::JsonParserError`] if the body is not a `StatusResponse`,
    /// and transport errors from the request itself.
    pub async fn switch_low_vram_setting(&self, new_low_vram_value: bool) -> Result<String, AllTalkError> {
        let url = self.baseurl.join("api/lowvramsetting")?;
        let response = self
            .client
            .post(url)
            .query(&[("new_low_vram_value", new_low_vram_value.to_string())])
            .send()
            .await?;
        let response_text = Self::success_body(response).await?;
        let status_response: StatusResponse = serde_json::from_str(&response_text).map_err(|_e| {
            AllTalkError::JsonParserError {
                json: response_text,
                target: String::from("StatusResponse"),
            }
        })?;
        Ok(status_response.status)
    }

    /// Generates speech by posting every parameter as a form field to
    /// `api/tts-generate`, returning the parsed `TTSGenerationResponse`.
    ///
    /// Booleans are sent as `true`/`false` and `autoplay_volume` via its
    /// `Display` form. Prefer [`AllTalkClient::generate_tts`] with a
    /// `TTSModelOptions` value.
    ///
    /// # Errors
    ///
    /// [`AllTalkError::ResponseError`] on a non-success status,
    /// [`AllTalkError::JsonParserError`] if the body is not a
    /// `TTSGenerationResponse`, and transport errors from the request itself.
    // The public signature mirrors the server's form fields one-to-one and is
    // kept as-is for compatibility.
    #[allow(clippy::too_many_arguments)]
    pub async fn generate_tts_from_parameters(
        &self,
        text_input: &str,
        text_filtering: TextFiltering,
        character_voice_gen: &str,
        narrator_enabled: bool,
        narrator_voice_gen: &str,
        text_not_inside: TextNotInside,
        language: Language,
        output_file_name: &str,
        output_file_timestamp: bool,
        autoplay: bool,
        autoplay_volume: f32,
    ) -> Result<TTSGenerationResponse, AllTalkError> {
        let url = self.baseurl.join("api/tts-generate")?;
        let response = self
            .client
            .post(url)
            .form(&[
                ("text_input", text_input),
                ("text_filtering", text_filtering.into()),
                ("character_voice_gen", character_voice_gen),
                ("narrator_enabled", &narrator_enabled.to_string()),
                ("narrator_voice_gen", narrator_voice_gen),
                ("text_not_inside", text_not_inside.into()),
                ("language", language.into()),
                ("output_file_name", output_file_name),
                ("output_file_timestamp", &output_file_timestamp.to_string()),
                ("autoplay", &autoplay.to_string()),
                ("autoplay_volume", &autoplay_volume.to_string()),
            ])
            .send()
            .await?;
        let response_text = Self::success_body(response).await?;
        let response: TTSGenerationResponse = serde_json::from_str(&response_text).map_err(|_e| {
            AllTalkError::JsonParserError {
                json: response_text,
                target: String::from("TTSGenerationResponse"),
            }
        })?;
        Ok(response)
    }

    /// Generates speech from a `TTSModelOptions` bundle.
    ///
    /// Thin wrapper over [`AllTalkClient::generate_tts_from_parameters`]; the
    /// enum-typed options are cloned because that method takes them by value.
    ///
    /// # Errors
    ///
    /// Same as [`AllTalkClient::generate_tts_from_parameters`].
    pub async fn generate_tts(&self, options: &TTSModelOptions) -> Result<TTSGenerationResponse, AllTalkError> {
        self.generate_tts_from_parameters(
            &options.text_input,
            options.text_filtering.clone(),
            &options.character_voice_gen,
            options.narrator_enabled,
            &options.narrator_voice_gen,
            options.text_not_inside.clone(),
            options.language.clone(),
            &options.output_file_name,
            options.output_file_timestamp,
            options.autoplay,
            options.autoplay_volume,
        )
        .await
    }

    /// Downloads `url` (resolved against the base URL with [`Url::join`]) and
    /// writes the body to `destination`, creating any missing parent
    /// directories. Returns `destination` on success.
    ///
    /// # Errors
    ///
    /// [`AllTalkError::ResponseError`] on a non-success status (nothing is
    /// written in that case), plus transport and filesystem errors.
    pub async fn download_file(&self, url: &str, destination: &str) -> Result<String, AllTalkError> {
        let url = self.baseurl.join(url)?;
        let response = self.client.get(url).send().await?;
        // Reject error responses so an error page is never saved as the file.
        if !response.status().is_success() {
            let response_text = response.text().await?;
            return Err(AllTalkError::ResponseError(response_text));
        }
        // Fetch the whole body before touching the filesystem so a failed
        // transfer does not leave an empty file at `destination`.
        let content = response.bytes().await?;

        let full_path = Path::new(destination);
        // A bare file name has an empty parent (the current directory), and
        // "" or "/" have none; only create directories for a real parent.
        // `create_dir_all` also creates missing ancestors and accepts existing ones.
        if let Some(parent_path) = full_path.parent().filter(|p| !p.as_os_str().is_empty()) {
            std::fs::create_dir_all(parent_path)?;
        }
        // NOTE(review): these are blocking std::fs calls inside an async fn.
        let mut file = std::fs::File::create(full_path)?;
        file.write_all(&content)?;
        file.flush()?;
        Ok(destination.to_owned())
    }
}