//! riva 1.0.0
//!
//! Server-backed Rust client for the Riva proxy (YouTube and SoundCloud).
use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
use reqwest::{Client, RequestBuilder, Response};
use serde::{Deserialize, Serialize};
#[cfg(feature = "youtube")]
use serde_json::Value;
use url::Url;

use crate::config::RivaConfig;
use crate::error::{ApiErrorBody, RivaError};

/// Client for a Riva server.
///
/// Bundles the `reqwest` client used to issue requests with the [`RivaConfig`] it was
/// built from.
#[derive(Debug, Clone)]
pub struct RivaClient {
    /// HTTP client; carries a default `Authorization` header when the configuration
    /// provides an access secret.
    http: Client,
    /// Configuration supplied at construction time (base URL, optional access secret).
    config: RivaConfig,
}

/// JSON body deserialized by [`RivaClient::health`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Status string reported by the server.
    pub status: String,
    /// Time string reported by the server; its format is not constrained by this type.
    pub time: String,
}

/// YouTube client type that may be passed to the YouTube endpoints as the
/// `client_type` query parameter.
#[cfg(feature = "youtube")]
#[derive(Debug, Clone, Copy)]
pub enum YoutubeClientType {
    /// Sent to the server as `WEB`.
    Web,
    /// Sent to the server as `ANDROID`.
    Android,
    /// Sent to the server as `IOS`.
    Ios,
}

#[cfg(feature = "youtube")]
impl YoutubeClientType {
    /// Returns the upper-case identifier used as the `client_type` query value.
    fn as_str(self) -> &'static str {
        let wire_name: &'static str = match self {
            YoutubeClientType::Web => "WEB",
            YoutubeClientType::Android => "ANDROID",
            YoutubeClientType::Ios => "IOS",
        };
        wire_name
    }
}

impl RivaClient {
    /// Builds a client from `config`.
    ///
    /// When `config.access_secret` is set, every request sent through this client carries
    /// an `Authorization: Bearer <secret>` default header. The header value is flagged as
    /// sensitive so that `Debug`-formatting the header map does not print the secret.
    ///
    /// # Errors
    ///
    /// Returns [`RivaError::InvalidInput`] when the secret cannot be encoded as a header
    /// value, or the error produced by the `reqwest` client builder.
    pub fn new(config: RivaConfig) -> Result<Self, RivaError> {
        let mut headers = HeaderMap::new();

        if let Some(secret) = &config.access_secret {
            let bearer = format!("Bearer {secret}");
            let mut value = HeaderValue::from_str(&bearer).map_err(|_| {
                RivaError::InvalidInput("access secret contains invalid header bytes".to_owned())
            })?;
            // Keep the bearer token out of `Debug` output and logs.
            value.set_sensitive(true);
            headers.insert(AUTHORIZATION, value);
        }

        let http = Client::builder().default_headers(headers).build()?;

        Ok(Self { http, config })
    }

    /// Builds a client from the configuration returned by [`RivaConfig::from_env`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`RivaConfig::from_env`] or [`RivaClient::new`].
    pub fn from_env() -> Result<Self, RivaError> {
        let config = RivaConfig::from_env()?;
        Self::new(config)
    }

    /// Returns the configuration this client was built with.
    pub fn config(&self) -> &RivaConfig {
        &self.config
    }

    /// Requests `health` relative to the base URL and decodes the JSON body.
    ///
    /// # Errors
    ///
    /// Fails when the URL cannot be built, the request fails, the server answers with a
    /// non-success status, or the body is not a valid [`HealthResponse`].
    pub async fn health(&self) -> Result<HealthResponse, RivaError> {
        let url = self.join_path("health")?;
        let response = self.send(self.http.get(url)).await?;
        Ok(response.json::<HealthResponse>().await?)
    }

    /// Requests `providers/youtube/{id}/info` and returns the body as untyped JSON.
    ///
    /// `id` may be a bare video id or a YouTube URL; it is normalized before use.
    /// When `client_type` is given it is appended as the `client_type` query parameter.
    ///
    /// # Errors
    ///
    /// Returns [`RivaError::InvalidInput`] when no video id can be extracted from `id`;
    /// otherwise fails like [`RivaClient::health`].
    #[cfg(feature = "youtube")]
    pub async fn youtube_info(
        &self,
        id: &str,
        client_type: Option<YoutubeClientType>,
    ) -> Result<Value, RivaError> {
        let id = normalize_youtube_id(id)?;
        let mut url = self.join_path(&format!("providers/youtube/{id}/info"))?;
        if let Some(client_type) = client_type {
            url.query_pairs_mut()
                .append_pair("client_type", client_type.as_str());
        }

        let response = self.send(self.http.get(url)).await?;
        Ok(response.json::<Value>().await?)
    }

    /// Requests `providers/youtube/{id}/stream/{itag}` and returns the raw response.
    ///
    /// The body is not read here, so the caller decides how to consume it. `id` and
    /// `client_type` are handled exactly as in [`RivaClient::youtube_info`].
    ///
    /// # Errors
    ///
    /// Returns [`RivaError::InvalidInput`] when no video id can be extracted from `id`,
    /// and [`RivaError::Api`] when the server answers with a non-success status.
    #[cfg(feature = "youtube")]
    pub async fn youtube_stream(
        &self,
        id: &str,
        itag: u32,
        client_type: Option<YoutubeClientType>,
    ) -> Result<Response, RivaError> {
        let id = normalize_youtube_id(id)?;
        let mut url = self.join_path(&format!("providers/youtube/{id}/stream/{itag}"))?;
        if let Some(client_type) = client_type {
            url.query_pairs_mut()
                .append_pair("client_type", client_type.as_str());
        }

        self.send(self.http.get(url)).await
    }

    /// Requests `providers/soundcloud/stream?url=<track>` and returns the raw response.
    ///
    /// `track_url` may be a SoundCloud URL or an `artist/track` pair; it is normalized
    /// to a full URL before being sent as the `url` query parameter.
    ///
    /// # Errors
    ///
    /// Returns [`RivaError::InvalidInput`] when `track_url` cannot be normalized, and
    /// [`RivaError::Api`] when the server answers with a non-success status.
    #[cfg(feature = "soundcloud")]
    pub async fn soundcloud_stream(&self, track_url: &str) -> Result<Response, RivaError> {
        let track_url = normalize_soundcloud_track_url(track_url)?;
        let mut url = self.join_path("providers/soundcloud/stream")?;
        url.query_pairs_mut().append_pair("url", &track_url);
        self.send(self.http.get(url)).await
    }

    /// Resolves `path` against the configured base URL.
    ///
    /// A failed join is reported as [`RivaError::InvalidBaseUrl`].
    // NOTE(review): `Url::join` replaces the last path segment of a base that lacks a
    // trailing slash; presumably `RivaConfig` normalizes the base URL — confirm.
    fn join_path(&self, path: &str) -> Result<Url, RivaError> {
        self.config
            .base_url
            .join(path)
            .map_err(RivaError::InvalidBaseUrl)
    }

    /// Sends `request` and maps non-success statuses to [`RivaError::Api`].
    ///
    /// On failure the body is read as text. If it parses as an [`ApiErrorBody`], its
    /// `error` field becomes the message; otherwise the raw body text is used.
    async fn send(&self, request: RequestBuilder) -> Result<Response, RivaError> {
        let response = request.send().await?;
        if response.status().is_success() {
            return Ok(response);
        }

        // Capture the status first: reading the body consumes the response.
        let status = response.status();
        let body = response.text().await?;
        let message = serde_json::from_str::<ApiErrorBody>(&body)
            .map(|v| v.error)
            .unwrap_or(body);

        Err(RivaError::Api { status, message })
    }
}

/// Rejects values that are empty or consist only of whitespace.
///
/// `field_name` is interpolated into the error message so the caller can tell which
/// input was rejected.
#[cfg(any(feature = "youtube", feature = "soundcloud"))]
fn validate_not_empty(field_name: &str, value: &str) -> Result<(), RivaError> {
    match value.trim() {
        "" => Err(RivaError::InvalidInput(format!(
            "{field_name} cannot be empty"
        ))),
        _ => Ok(()),
    }
}

/// Extracts a YouTube video id from user input.
///
/// Three strategies are tried in order:
/// 1. the trimmed input already is a video id;
/// 2. the input parses as a URL from which an id can be extracted;
/// 3. the text before the first `?`, `&`, `/` or `#` is a video id.
///
/// Returns [`RivaError::InvalidInput`] when the input is blank or none of the
/// strategies yields an id.
#[cfg(feature = "youtube")]
fn normalize_youtube_id(input: &str) -> Result<String, RivaError> {
    let trimmed = input.trim();
    validate_not_empty("youtube id", trimmed)?;

    // Strategy 1: bare id.
    if is_youtube_video_id(trimmed) {
        return Ok(trimmed.to_owned());
    }

    // Strategy 2: URL, with or without a scheme.
    let from_url = parse_possible_url(trimmed).and_then(|parsed| youtube_id_from_url(&parsed));
    if let Some(video_id) = from_url {
        return Ok(video_id);
    }

    // Strategy 3: id followed by trailing query/path/fragment noise.
    let head = trimmed
        .split(['?', '&', '/', '#'])
        .next()
        .unwrap_or_default()
        .trim();
    if is_youtube_video_id(head) {
        return Ok(head.to_owned());
    }

    Err(RivaError::InvalidInput(
        "could not extract a valid YouTube video id from input".to_owned(),
    ))
}

/// Extracts a video id from a YouTube URL.
///
/// Supported shapes:
/// - `youtu.be/<id>`: the first non-empty path segment;
/// - `youtube.com` or any of its subdomains with a `v` or `vi` query parameter;
/// - `youtube.com` or any of its subdomains with a `/shorts/<id>`, `/embed/<id>`,
///   `/v/<id>` or `/live/<id>` path.
///
/// Returns `None` when the host is not a YouTube host or no candidate passes
/// [`is_youtube_video_id`].
#[cfg(feature = "youtube")]
fn youtube_id_from_url(url: &Url) -> Option<String> {
    let host = url.host_str()?.to_ascii_lowercase();

    if host == "youtu.be" {
        return first_path_segment(url)
            .filter(|id| is_youtube_video_id(id))
            .map(ToString::to_string);
    }

    // Require the exact domain or a real subdomain. A bare suffix test would also
    // accept unrelated hosts such as `notyoutube.com`.
    if host != "youtube.com" && !host.ends_with(".youtube.com") {
        return None;
    }

    // Only the first `v`/`vi` pair is considered; if it is not a valid id the path
    // is inspected instead.
    if let Some(v) = url.query_pairs().find_map(|(k, v)| {
        if k == "v" || k == "vi" {
            Some(v.into_owned())
        } else {
            None
        }
    }) && is_youtube_video_id(&v)
    {
        return Some(v);
    }

    let mut segments = url.path_segments()?;
    let first = segments.next()?;
    let second = segments.next();

    match first {
        "shorts" | "embed" | "v" | "live" => second
            .filter(|id| is_youtube_video_id(id))
            .map(ToString::to_string),
        _ => None,
    }
}

#[cfg(feature = "youtube")]
/// Returns `true` when `value` is exactly 11 bytes long and every byte is an ASCII
/// letter, an ASCII digit, `-` or `_`.
fn is_youtube_video_id(value: &str) -> bool {
    const ID_LEN: usize = 11;

    if value.len() != ID_LEN {
        return false;
    }

    value
        .bytes()
        .all(|byte| matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'_'))
}

/// Turns user input into a SoundCloud track URL string.
///
/// Input that parses as a URL on a SoundCloud host is returned in its parsed,
/// serialized form. Otherwise the input must be exactly `artist/track` (surrounding
/// slashes are ignored), which is expanded to `https://soundcloud.com/artist/track`.
///
/// Returns [`RivaError::InvalidInput`] when the input is blank or matches neither form.
#[cfg(feature = "soundcloud")]
fn normalize_soundcloud_track_url(input: &str) -> Result<String, RivaError> {
    let trimmed = input.trim();
    validate_not_empty("soundcloud track url", trimmed)?;

    let soundcloud_url = parse_possible_url(trimmed).filter(is_soundcloud_host);
    if let Some(parsed) = soundcloud_url {
        return Ok(String::from(parsed));
    }

    // Shorthand form: exactly two non-empty segments.
    let mut segments = trimmed.trim_matches('/').split('/').map(str::trim);
    let artist = segments.next().unwrap_or_default();
    let track = segments.next().unwrap_or_default();
    let has_extra_segment = segments.next().is_some();

    if artist.is_empty() || track.is_empty() || has_extra_segment {
        return Err(RivaError::InvalidInput(
            "soundcloud input must be a soundcloud URL or in the format artist/track".to_owned(),
        ));
    }

    Ok(format!("https://soundcloud.com/{artist}/{track}"))
}

/// Parses `input` as a URL, retrying with an `https://` prefix when the first
/// attempt fails (for example because the scheme is missing).
#[cfg(any(feature = "youtube", feature = "soundcloud"))]
fn parse_possible_url(input: &str) -> Option<Url> {
    match Url::parse(input) {
        Ok(parsed) => Some(parsed),
        Err(_) => Url::parse(&format!("https://{input}")).ok(),
    }
}

/// Returns the first non-empty path segment of `url`, or `None` when the URL has no
/// path segments or all of them are empty.
#[cfg(feature = "youtube")]
fn first_path_segment(url: &Url) -> Option<&str> {
    for segment in url.path_segments()? {
        if !segment.is_empty() {
            return Some(segment);
        }
    }
    None
}

/// Returns `true` when the URL's host is `soundcloud.com` or one of its subdomains
/// (compared case-insensitively). URLs without a host yield `false`.
#[cfg(feature = "soundcloud")]
fn is_soundcloud_host(url: &Url) -> bool {
    let Some(raw_host) = url.host_str() else {
        return false;
    };

    let host = raw_host.to_ascii_lowercase();
    host.ends_with(".soundcloud.com") || host == "soundcloud.com"
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DEFAULT_RIVA_BASE_URL;

    /// Video id shared by the YouTube normalization tests.
    #[cfg(feature = "youtube")]
    const SAMPLE_VIDEO_ID: &str = "dQw4w9WgXcQ";

    /// A default configuration resolves to the default base URL plus a trailing slash.
    #[test]
    fn client_defaults_to_public_server() {
        let client = RivaClient::new(RivaConfig::default()).expect("client should build");
        let expected = format!("{DEFAULT_RIVA_BASE_URL}/");
        assert_eq!(client.config().base_url.as_str(), expected);
    }

    /// Each client type maps to its upper-case wire value.
    #[cfg(feature = "youtube")]
    #[test]
    fn youtube_client_type_values_match_api() {
        let expectations = [
            (YoutubeClientType::Web, "WEB"),
            (YoutubeClientType::Android, "ANDROID"),
            (YoutubeClientType::Ios, "IOS"),
        ];

        for (client_type, wire_value) in expectations {
            assert_eq!(client_type.as_str(), wire_value);
        }
    }

    /// Whitespace-only input is rejected as invalid.
    #[cfg(any(feature = "youtube", feature = "soundcloud"))]
    #[test]
    fn validates_non_empty_fields() {
        assert!(matches!(
            validate_not_empty("x", "   "),
            Err(RivaError::InvalidInput(_))
        ));
    }

    /// A bare video id is returned unchanged.
    #[cfg(feature = "youtube")]
    #[test]
    fn normalizes_youtube_id_from_direct_value() {
        let normalized = normalize_youtube_id(SAMPLE_VIDEO_ID).expect("id should normalize");
        assert_eq!(normalized, SAMPLE_VIDEO_ID);
    }

    /// Watch, short-link and scheme-less shorts URLs all yield the same id.
    #[cfg(feature = "youtube")]
    #[test]
    fn normalizes_youtube_id_from_common_urls() {
        let inputs = [
            "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
            "https://youtu.be/dQw4w9WgXcQ?t=43",
            "youtube.com/shorts/dQw4w9WgXcQ",
        ];

        for input in inputs {
            let normalized = normalize_youtube_id(input).expect("url should normalize");
            assert_eq!(normalized, SAMPLE_VIDEO_ID);
        }
    }

    /// Full URLs pass through and `artist/track` shorthand is expanded.
    #[cfg(feature = "soundcloud")]
    #[test]
    fn normalizes_soundcloud_from_url_or_artist_track() {
        let expected = "https://soundcloud.com/kordhell/trageluxe";

        let from_url = normalize_soundcloud_track_url("https://soundcloud.com/kordhell/trageluxe")
            .expect("url should normalize");
        assert_eq!(from_url, expected);

        let from_shorthand = normalize_soundcloud_track_url("kordhell/trageluxe")
            .expect("artist/track should normalize");
        assert_eq!(from_shorthand, expected);
    }

    /// A custom base URL is stored with a trailing slash.
    #[test]
    fn can_create_new_base_url() {
        let custom = RivaConfig::new("https://riva.resonix.dev").expect("config should be valid");
        let client = RivaClient::new(custom).expect("client should build");
        assert_eq!(client.config.base_url.as_str(), "https://riva.resonix.dev/");
    }
}