//! alltalk 0.1.0
//!
//! A client for the AllTalk API.
mod errors;
mod requests;
mod responses;

use std::{io::Write, path::Path};

pub use errors::AllTalkError;
pub use requests::*;
use reqwest::{Client, ClientBuilder};
pub use responses::*;
use url::Url;

/// Asynchronous HTTP client for an AllTalk TTS server.
///
/// Every endpoint path is resolved relative to `baseurl` with [`Url::join`].
#[derive(Clone, Debug)]
pub struct AllTalkClient {
  /// Underlying `reqwest` HTTP client used for every request.
  pub client: Client,
  /// Base URL of the AllTalk server; endpoint paths are joined onto it.
  pub(crate) baseurl: Url,
}

impl AllTalkClient {
  /// Creates a client that talks to the AllTalk server at `baseurl`.
  ///
  /// Endpoint paths are resolved with [`Url::join`], so a base URL that has a
  /// path component should end with `/`. The HTTP client uses a 10-minute
  /// request timeout and keeps up to 20 idle connections per host.
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ReqwestError`] if the HTTP client cannot be built.
  pub fn from_url(baseurl: Url) -> Result<Self, AllTalkError> {
    let client = ClientBuilder::new()
      .user_agent("alltalk-client/0.1.0")
      .pool_max_idle_per_host(20)
      .timeout(std::time::Duration::from_secs(60 * 10))
      .build()
      .map_err(AllTalkError::ReqwestError)?;
    Ok(Self { client, baseurl })
  }

  /// Creates a client from the `ALLTALK_URL` environment variable.
  ///
  /// # Panics
  ///
  /// Panics if `ALLTALK_URL` is not set or is not valid Unicode.
  ///
  /// # Errors
  ///
  /// Returns an error if `ALLTALK_URL` is not a valid URL, or if the HTTP
  /// client cannot be built.
  pub fn from_environment() -> Result<Self, AllTalkError> {
    let baseurl = std::env::var("ALLTALK_URL").expect("ALLTALK_URL is not set");
    // This function already returns `Result`, so a malformed URL is reported
    // to the caller instead of aborting the program.
    let baseurl = Url::parse(&baseurl)?;
    Self::from_url(baseurl)
  }

  /// Reads the body of `response`, turning a non-2xx status into an error.
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ResponseError`] carrying the body text when the
  /// status is not 2xx, or an error if the body cannot be read.
  async fn success_text(response: reqwest::Response) -> Result<String, AllTalkError> {
    let is_success = response.status().is_success();
    let response_text = response.text().await?;
    if is_success {
      Ok(response_text)
    } else {
      Err(AllTalkError::ResponseError(response_text))
    }
  }

  /// Checks whether the Text-to-Speech (TTS) service is ready to accept requests.
  ///
  /// ```text
  /// curl -X GET "http://127.0.0.1:7851/api/ready"
  /// ```
  ///
  /// Returns `Ok(false)` for a non-2xx status. Otherwise returns `Ok(true)`
  /// only when the body is `Ready`, ignoring surrounding whitespace.
  ///
  /// # Errors
  ///
  /// Returns an error if the request fails or the body cannot be read.
  pub async fn is_ready(&self) -> Result<bool, AllTalkError> {
    let url = self.baseurl.join("api/ready")?;
    let response = self.client.get(url).send().await?;
    if !response.status().is_success() {
      return Ok(false);
    }
    let response_text = response.text().await?;
    // Tolerate a trailing newline or other whitespace around the marker.
    Ok(response_text.trim() == "Ready")
  }

  /// Retrieves the list of voices available for generating speech.
  ///
  /// ```text
  /// curl -X GET "http://127.0.0.1:7851/api/voices"
  /// ```
  ///
  /// Expected JSON: `{"voices": ["voice1.wav", "voice2.wav", "voice3.wav"]}`.
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ResponseError`] for a non-2xx status, or
  /// [`AllTalkError::JsonParserError`] if the body is not a `VoicesResponse`.
  pub async fn get_voices(&self) -> Result<Vec<String>, AllTalkError> {
    let url = self.baseurl.join("api/voices")?;
    let response = self.client.get(url).send().await?;
    let response_text = Self::success_text(response).await?;
    let voices: VoicesResponse = serde_json::from_str(&response_text).map_err(|_e| {
      AllTalkError::JsonParserError {
        json: response_text,
        target: String::from("VoicesResponse"),
      }
    })?;
    Ok(voices.voices)
  }

  /// Retrieves the server's current settings.
  ///
  /// ```text
  /// curl -X GET "http://127.0.0.1:7851/api/currentsettings"
  /// ```
  ///
  /// Example JSON:
  ///
  /// ```text
  /// {"models_available":[{"name":"Coqui","model_name":"API TTS"},{"name":"Coqui","model_name":"API Local"},{"name":"Coqui","model_name":"XTTSv2 Local"}],"current_model_loaded":"XTTSv2 Local","deepspeed_available":true,"deepspeed_status":true,"low_vram_status":true,"finetuned_model":false}
  /// ```
  ///
  /// - `name` / `model_name`: the currently available models.
  /// - `current_model_loaded`: the model currently loaded into VRAM.
  /// - `deepspeed_available`: whether DeepSpeed was detected on startup.
  /// - `deepspeed_status`: whether DeepSpeed is currently activated.
  /// - `low_vram_status`: whether Low VRAM mode is currently enabled.
  /// - `finetuned_model`: whether a finetuned model (XTTSv2 FT) was detected.
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ResponseError`] for a non-2xx status, or
  /// [`AllTalkError::JsonParserError`] if the body is not a
  /// `CurrentSettingsResponse`.
  pub async fn get_current_settings(&self) -> Result<CurrentSettingsResponse, AllTalkError> {
    let url = self.baseurl.join("api/currentsettings")?;
    let response = self.client.get(url).send().await?;
    let response_text = Self::success_text(response).await?;
    let current_settings: CurrentSettingsResponse = serde_json::from_str(&response_text).map_err(|_e| {
      AllTalkError::JsonParserError {
        json: response_text,
        target: String::from("CurrentSettingsResponse"),
      }
    })?;
    Ok(current_settings)
  }

  /// Generates a preview of `voice` (e.g. `female_01.wav`) with the server's
  /// hardcoded settings and returns the `output_file_url` of the result.
  ///
  /// ```text
  /// curl -X POST "http://127.0.0.1:7851/api/previewvoice/" -F "voice=female_01.wav"
  /// ```
  ///
  /// The upstream curl example uses `-F`, which sends multipart form data.
  /// This client sends the field URL-encoded instead (`RequestBuilder::form`).
  ///
  /// Example JSON: `{"status": "generate-success", "output_file_path": "/path/to/outputs/api_preview_voice.wav", "output_file_url": "http://127.0.0.1:7851/audio/api_preview_voice.wav"}`.
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ResponseError`] for a non-2xx status or when
  /// `output_file_url` is missing or not a string, or
  /// [`AllTalkError::JsonParserError`] if the body is not JSON.
  pub async fn preview_voice(&self, voice: &str) -> Result<String, AllTalkError> {
    let url = self.baseurl.join("api/previewvoice/")?;
    let response = self.client.post(url).form(&[("voice", voice)]).send().await?;
    let response_text = Self::success_text(response).await?;
    let response_json: serde_json::Value = serde_json::from_str(&response_text).map_err(|_e| {
      AllTalkError::JsonParserError {
        json: response_text,
        target: String::from("serde_json::Value"),
      }
    })?;
    let output_file_url = response_json["output_file_url"]
      .as_str()
      .ok_or_else(|| AllTalkError::ResponseError("output_file_url not found".to_string()))?;
    Ok(output_file_url.to_string())
  }

  /// Switches the loaded TTS model and returns the server's `status` string
  /// (e.g. `model-success`).
  ///
  /// `tts_method` is one of `API Local`, `API TTS` or `XTTSv2 Local`. It can
  /// also be `XTTSv2 FT`, which requires a finetuned model in
  /// `/models/trainedmodel/`; the server errors otherwise.
  ///
  /// ```text
  /// curl -X POST "http://127.0.0.1:7851/api/reload?tts_method=XTTSv2%20Local"
  /// ```
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ResponseError`] for a non-2xx status, or
  /// [`AllTalkError::JsonParserError`] if the body is not a `StatusResponse`.
  pub async fn switch_model(&self, tts_method: &str) -> Result<String, AllTalkError> {
    let url = self.baseurl.join("api/reload")?;
    let response = self
      .client
      .post(url)
      .query(&[("tts_method", tts_method)])
      .send()
      .await?;
    let response_text = Self::success_text(response).await?;
    let status_response: StatusResponse = serde_json::from_str(&response_text).map_err(|_e| {
      AllTalkError::JsonParserError {
        json: response_text,
        target: String::from("StatusResponse"),
      }
    })?;
    Ok(status_response.status)
  }

  /// Enables or disables DeepSpeed and returns the server's `status` string
  /// (e.g. `deepspeed-success`).
  ///
  /// ```text
  /// curl -X POST "http://127.0.0.1:7851/api/deepspeed?new_deepspeed_value=True"
  /// ```
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ResponseError`] for a non-2xx status, or
  /// [`AllTalkError::JsonParserError`] if the body is not a `StatusResponse`.
  pub async fn switch_deepspeed(&self, new_deepspeed_value: bool) -> Result<String, AllTalkError> {
    let url = self.baseurl.join("api/deepspeed")?;
    let response = self
      .client
      .post(url)
      .query(&[("new_deepspeed_value", new_deepspeed_value.to_string())])
      .send()
      .await?;
    let response_text = Self::success_text(response).await?;
    let status_response: StatusResponse = serde_json::from_str(&response_text).map_err(|_e| {
      AllTalkError::JsonParserError {
        json: response_text,
        target: String::from("StatusResponse"),
      }
    })?;
    Ok(status_response.status)
  }

  /// Enables or disables Low VRAM mode and returns the server's `status`
  /// string (e.g. `lowvram-success`).
  ///
  /// ```text
  /// curl -X POST "http://127.0.0.1:7851/api/lowvramsetting?new_low_vram_value=True"
  /// ```
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ResponseError`] for a non-2xx status, or
  /// [`AllTalkError::JsonParserError`] if the body is not a `StatusResponse`.
  pub async fn switch_low_vram_setting(&self, new_low_vram_value: bool) -> Result<String, AllTalkError> {
    let url = self.baseurl.join("api/lowvramsetting")?;
    let response = self
      .client
      .post(url)
      .query(&[("new_low_vram_value", new_low_vram_value.to_string())])
      .send()
      .await?;
    let response_text = Self::success_text(response).await?;
    let status_response: StatusResponse = serde_json::from_str(&response_text).map_err(|_e| {
      AllTalkError::JsonParserError {
        json: response_text,
        target: String::from("StatusResponse"),
      }
    })?;
    Ok(status_response.status)
  }

  /// Generates speech from text.
  ///
  /// Each argument is sent as the form field of the same name to
  /// `api/tts-generate`. Standard generation supports narration and
  /// produces a wav file.
  ///
  /// ```text
  /// curl -X POST "http://127.0.0.1:7851/api/tts-generate" -d "text_input=All of this is text spoken by the character." -d "text_filtering=standard" -d "character_voice_gen=female_01.wav" -d "narrator_enabled=false" -d "narrator_voice_gen=male_01.wav" -d "text_not_inside=character" -d "language=en" -d "output_file_name=myoutputfile" -d "output_file_timestamp=true" -d "autoplay=true" -d "autoplay_volume=0.8"
  /// ```
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ResponseError`] for a non-2xx status, or
  /// [`AllTalkError::JsonParserError`] if the body is not a
  /// `TTSGenerationResponse`.
  // The parameter list mirrors the endpoint's form fields one-to-one; see
  // `generate_tts` for the struct-based entry point.
  #[allow(clippy::too_many_arguments)]
  pub async fn generate_tts_from_parameters(
    &self,
    text_input: &str,
    text_filtering: TextFiltering,
    character_voice_gen: &str,
    narrator_enabled: bool,
    narrator_voice_gen: &str,
    text_not_inside: TextNotInside,
    language: Language,
    output_file_name: &str,
    output_file_timestamp: bool,
    autoplay: bool,
    autoplay_volume: f32,
  ) -> Result<TTSGenerationResponse, AllTalkError> {
    let url = self.baseurl.join("api/tts-generate")?;
    let response = self
      .client
      .post(url)
      .form(&[
        ("text_input", text_input),
        ("text_filtering", text_filtering.into()),
        ("character_voice_gen", character_voice_gen),
        ("narrator_enabled", &narrator_enabled.to_string()),
        ("narrator_voice_gen", narrator_voice_gen),
        ("text_not_inside", text_not_inside.into()),
        ("language", language.into()),
        ("output_file_name", output_file_name),
        ("output_file_timestamp", &output_file_timestamp.to_string()),
        ("autoplay", &autoplay.to_string()),
        ("autoplay_volume", &autoplay_volume.to_string()),
      ])
      .send()
      .await?;
    let response_text = Self::success_text(response).await?;
    let response: TTSGenerationResponse = serde_json::from_str(&response_text).map_err(|_e| {
      AllTalkError::JsonParserError {
        json: response_text,
        target: String::from("TTSGenerationResponse"),
      }
    })?;
    Ok(response)
  }

  /// Generates speech from text using the fields of `options`.
  ///
  /// Convenience wrapper around [`Self::generate_tts_from_parameters`].
  ///
  /// # Errors
  ///
  /// Same as [`Self::generate_tts_from_parameters`].
  pub async fn generate_tts(&self, options: &TTSModelOptions) -> Result<TTSGenerationResponse, AllTalkError> {
    self
      .generate_tts_from_parameters(
        &options.text_input,
        options.text_filtering.clone(),
        &options.character_voice_gen,
        options.narrator_enabled,
        &options.narrator_voice_gen,
        options.text_not_inside.clone(),
        options.language.clone(),
        &options.output_file_name,
        options.output_file_timestamp,
        options.autoplay,
        options.autoplay_volume,
      )
      .await
  }

  /// Downloads `url` and writes the body to `destination`, returning
  /// `destination` on success.
  ///
  /// `url` is resolved against the base URL with [`Url::join`]. A relative
  /// path such as `audio/file.wav` works, and so does an absolute URL like the
  /// `output_file_url` returned by the server. Missing parent directories of
  /// `destination` are created.
  ///
  /// NOTE(review): the file-system calls are blocking `std::fs` I/O inside an
  /// async fn.
  ///
  /// # Errors
  ///
  /// Returns [`AllTalkError::ResponseError`] for a non-2xx status, in which
  /// case nothing is written. Returns an error if URL resolution, the request,
  /// reading the body or any file-system operation fails.
  pub async fn download_file(&self, url: &str, destination: &str) -> Result<String, AllTalkError> {
    let url = self.baseurl.join(url)?;
    let response = self.client.get(url).send().await?;
    if !response.status().is_success() {
      let response_text = response.text().await?;
      return Err(AllTalkError::ResponseError(response_text));
    }
    // Fetch the whole body before touching the file system, so a failed
    // transfer doesn't leave an empty file behind.
    let content = response.bytes().await?;

    let full_path = Path::new(destination);
    // `parent()` is `Some("")` for a bare file name and `None` for a root path.
    // Neither case needs a directory created.
    if let Some(parent) = full_path.parent().filter(|p| !p.as_os_str().is_empty()) {
      std::fs::create_dir_all(parent)?;
    }
    let mut file = std::fs::File::create(full_path)?;
    file.write_all(&content)?;
    file.flush()?;

    Ok(destination.to_owned())
  }
}

// #[cfg(test)]
// mod tests {
//   use wiremock::{
//     matchers::{method, path},
//     Mock, MockServer, ResponseTemplate,
//   };

//   use super::*;

//   #[tokio::test]
//   async fn test_is_ready() {
//     let mock_server = MockServer::start().await;
//     Mock::given(method("GET"))
//       .and(path("/api/ready"))
//       .respond_with(ResponseTemplate::new(200).set_body_string("Ready"))
//       .mount(&mock_server)
//       .await;

//     let baseurl = Url::parse(&mock_server.uri()).unwrap();
//     let client = AllTalkClient::from_url(baseurl).unwrap();
//     let result = client.is_ready().await.unwrap();
//     assert!(result);
//   }
// }