//! llmservice-flows 0.5.0
//!
//! LLM Service integration for flows.network.
use reqwest::multipart;
use serde::Deserialize;

use crate::LLMApi;
use crate::Retry;

/// Request parameters for transcribing audio with `LLMServiceFlows::transcribe`.
///
/// Every field is sent to the service as a multipart form field; the `Option`
/// fields are only sent when they are `Some`.
#[derive(Debug, Clone)]
pub struct TranscribeInput {
    /// Raw audio bytes, uploaded as the `file` part.
    pub audio: Vec<u8>,
    /// Extension used for the uploaded file name (`audio.<audio_format>`), e.g. `"wav"`.
    pub audio_format: String,
    /// Value of the `language` form field, e.g. `"en"`.
    pub language: String,
    /// Optional `max_len` form field; omitted when `None`.
    pub max_len: Option<u64>,
    /// Optional `max_context` form field; omitted when `None`.
    pub max_context: Option<i32>,
    /// Optional `split_on_word` form field; omitted when `None`.
    pub split_on_word: Option<bool>,
}

impl LLMApi for TranscribeInput {
    type Output = TranscriptionOutput;

    /// Sends this input to the `/audio/transcriptions` endpoint under `endpoint`.
    ///
    /// The returned [`Retry`] indicates whether the caller may try the request again.
    async fn api(&self, endpoint: &str, api_key: &str) -> Retry<Self::Output> {
        // `self` is already `&Self`; pass it directly instead of borrowing it again.
        transcribe_inner(endpoint, api_key, self).await
    }
}

/// Successful result of a transcription request.
///
/// Deserialized from the JSON body returned by the `/audio/transcriptions` endpoint.
#[derive(Deserialize, Debug)]
pub struct TranscriptionOutput {
    /// Text returned by the service for the submitted audio.
    pub text: String,
}
/// Request parameters for translating audio with `LLMServiceFlows::translate`.
///
/// Every field is sent to the service as a multipart form field; the `Option`
/// fields are only sent when they are `Some`.
#[derive(Debug, Clone)]
pub struct TranslateInput {
    /// Raw audio bytes, uploaded as the `file` part.
    pub audio: Vec<u8>,
    /// Extension used for the uploaded file name (`audio.<audio_format>`), e.g. `"wav"`.
    pub audio_format: String,
    /// Value of the `language` form field, e.g. `"zh"`.
    pub language: String,
    /// Optional `max_len` form field; omitted when `None`.
    pub max_len: Option<u64>,
    /// Optional `max_context` form field; omitted when `None`.
    pub max_context: Option<i32>,
    /// Optional `split_on_word` form field; omitted when `None`.
    pub split_on_word: Option<bool>,
}

impl LLMApi for TranslateInput {
    type Output = TranslationOutput;

    /// Sends this input to the `/audio/translations` endpoint under `endpoint`.
    ///
    /// The returned [`Retry`] indicates whether the caller may try the request again.
    async fn api(&self, endpoint: &str, api_key: &str) -> Retry<Self::Output> {
        // `self` is already `&Self`; pass it directly instead of borrowing it again.
        translate_inner(endpoint, api_key, self).await
    }
}

/// Successful result of a translation request.
///
/// Deserialized from the JSON body returned by the `/audio/translations` endpoint.
#[derive(Deserialize, Debug)]
pub struct TranslationOutput {
    /// Text returned by the service for the submitted audio.
    pub text: String,
}

impl<'a> crate::LLMServiceFlows<'a> {
    /// Transcribes audio into text in the language given by `input.language`.
    ///
    /// `input` is a [`TranscribeInput`] holding the audio and request options.
    /// The request is dispatched through `keep_trying` (defined elsewhere in the
    /// crate), which presumably re-issues it while the service reports a retryable
    /// status.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a message describing the failure when the request does
    /// not produce a transcription.
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// // Transcribe English audio; `audio` (a `Vec<u8>`) was collected in a previous step.
    /// let input = TranscribeInput {
    ///     audio,
    ///     audio_format: "wav".to_string(),
    ///     language: "en".to_string(),
    ///     max_len: Some(0),
    ///     max_context: Some(-1),
    ///     split_on_word: Some(false),
    /// };
    /// // Call the transcribe function.
    /// let transcription = match llm.transcribe(input).await {
    ///     Ok(r) => r.text,
    ///     Err(e) => { /* your error handling */ }
    /// };
    /// ```
    pub async fn transcribe(&self, input: TranscribeInput) -> Result<TranscriptionOutput, String> {
        self.keep_trying(input).await
    }

    /// Translates audio into English text.
    ///
    /// `input` is a [`TranslateInput`] holding the audio and request options; its
    /// `language` is sent as the `language` form field (presumably the spoken
    /// language of the source audio, e.g. `"zh"` — confirm against the service).
    /// The request is dispatched through `keep_trying` (defined elsewhere in the
    /// crate), which presumably re-issues it while the service reports a retryable
    /// status.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a message describing the failure when the request does
    /// not produce a translation.
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// // Translate Chinese audio into English; `audio` (a `Vec<u8>`) was collected
    /// // in a previous step.
    /// let input = TranslateInput {
    ///     audio,
    ///     audio_format: "wav".to_string(),
    ///     language: "zh".to_string(),
    ///     max_len: Some(0),
    ///     max_context: Some(-1),
    ///     split_on_word: Some(false),
    /// };
    /// // Call the translate function.
    /// let translation = match llm.translate(input).await {
    ///     Ok(r) => r.text,
    ///     Err(e) => { /* your error handling */ }
    /// };
    /// ```
    pub async fn translate(&self, input: TranslateInput) -> Result<TranslationOutput, String> {
        self.keep_trying(input).await
    }
}

async fn transcribe_inner(
    endpoint: &str,
    _api_key: &str,
    input: &TranscribeInput,
) -> Retry<TranscriptionOutput> {
    let uri = format!("{}/audio/transcriptions", endpoint);

    let mut form = multipart::Form::new()
        .part(
            "file",
            multipart::Part::bytes(input.audio.clone())
                .file_name(format!("audio.{}", input.audio_format)),
        )
        .part("language", multipart::Part::text(input.language.clone()));
    if input.max_len.is_some() {
        form = form.part(
            "max_len",
            multipart::Part::text(input.max_len.unwrap().to_string()),
        );
    }
    if input.max_context.is_some() {
        form = form.part(
            "max_context",
            multipart::Part::text(input.max_context.unwrap().to_string()),
        );
    }
    if input.split_on_word.is_some() {
        form = form.part(
            "split_on_word",
            multipart::Part::text(input.split_on_word.unwrap().to_string()),
        );
    }

    match reqwest::Client::new()
        .post(uri)
        .multipart(form)
        .send()
        .await
    {
        Ok(res) => {
            let status = res.status();
            let body = res.bytes().await.unwrap();
            match status.is_success() {
                true => Retry::No(
                    serde_json::from_slice::<TranscriptionOutput>(&body.as_ref())
                        .or(Err(String::from("Unexpected error"))),
                ),
                false => {
                    match status.into() {
                        409 | 429 | 503 => {
                            // 409 TryAgain 429 RateLimitError
                            // 503 ServiceUnavailable
                            Retry::Yes(String::from_utf8_lossy(&body.as_ref()).into_owned())
                        }
                        _ => Retry::No(Err(String::from_utf8_lossy(&body.as_ref()).into_owned())),
                    }
                }
            }
        }
        Err(e) => Retry::No(Err(e.to_string())),
    }
}

async fn translate_inner(
    endpoint: &str,
    _api_key: &str,
    input: &TranslateInput,
) -> Retry<TranslationOutput> {
    let uri = format!("{}/audio/translations", endpoint);

    let mut form = multipart::Form::new()
        .part(
            "file",
            multipart::Part::bytes(input.audio.clone())
                .file_name(format!("audio.{}", input.audio_format)),
        )
        .part("language", multipart::Part::text(input.language.clone()));
    if input.max_len.is_some() {
        form = form.part(
            "max_len",
            multipart::Part::text(input.max_len.unwrap().to_string()),
        );
    }
    if input.max_context.is_some() {
        form = form.part(
            "max_context",
            multipart::Part::text(input.max_context.unwrap().to_string()),
        );
    }
    if input.split_on_word.is_some() {
        form = form.part(
            "split_on_word",
            multipart::Part::text(input.split_on_word.unwrap().to_string()),
        );
    }

    match reqwest::Client::new()
        .post(uri)
        .multipart(form)
        .send()
        .await
    {
        Ok(res) => {
            let status = res.status();
            let body = res.bytes().await.unwrap();
            match status.is_success() {
                true => Retry::No(
                    serde_json::from_slice::<TranslationOutput>(&body.as_ref())
                        .or(Err(String::from("Unexpected error"))),
                ),
                false => {
                    match status.into() {
                        409 | 429 | 503 => {
                            // 409 TryAgain 429 RateLimitError
                            // 503 ServiceUnavailable
                            Retry::Yes(String::from_utf8_lossy(&body.as_ref()).into_owned())
                        }
                        _ => Retry::No(Err(String::from_utf8_lossy(&body.as_ref()).into_owned())),
                    }
                }
            }
        }
        Err(e) => Retry::No(Err(e.to_string())),
    }
}