use serde::{Deserialize, Serialize};
use super::clip::Clip;
/// Value placed in `metadata.web_client_pathname` by `GenerateMetadata::new_with_context`.
const WEB_CLIENT_PATHNAME: &str = "/create";
/// Value the `GenerateRequest` constructors put in `generation_type`.
const GENERATION_TYPE_TEXT: &str = "TEXT";
/// `token_provider` value recorded whenever `set_challenge_token` receives `Some(token)`.
const CHALLENGE_TOKEN_PROVIDER: u8 = 1;
/// Per-user values from the web session that feed into generation request metadata.
#[derive(Debug, Clone, Default)]
pub struct GenerationWebContext {
    /// Account tier for the user.
    ///
    /// The tier is trimmed before use. `None`, empty and whitespace-only values
    /// all end up as an empty string in the request metadata.
    pub user_tier: Option<String>,
}

impl GenerationWebContext {
    /// Returns the trimmed user tier, or an empty string when none is set.
    ///
    /// A blank tier trims down to `""`, which is exactly the fallback value,
    /// so no separate emptiness check is needed.
    fn user_tier_value(&self) -> String {
        self.user_tier
            .as_deref()
            .map_or("", str::trim)
            .to_owned()
    }
}
/// JSON request body for a generation call.
///
/// The field declaration order and the `serde` attributes define the serialized
/// shape. `Option` fields marked `skip_serializing_if = "Option::is_none"` are
/// omitted when `None`. Every other `Option` field is still emitted, as `null`.
#[derive(Debug, Serialize)]
pub struct GenerateRequest {
    /// Challenge token. Set it through `GenerateRequest::set_challenge_token`
    /// so that `token_provider` stays consistent with it.
    pub token: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task: Option<String>,
    /// The constructors always set this to `"TEXT"`.
    pub generation_type: String,
    pub title: Option<String>,
    pub tags: Option<String>,
    pub negative_tags: String,
    /// The constructor's `mv` argument (e.g. `"chirp-fenix"`); presumably a
    /// model version identifier — confirm against the API.
    pub mv: String,
    pub prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gpt_description_prompt: Option<String>,
    pub make_instrumental: bool,
    pub user_uploaded_images_b64: Option<String>,
    /// Web-client metadata built from the create mode and `GenerationWebContext`.
    pub metadata: GenerateMetadata,
    pub override_fields: Vec<String>,
    // The cover / persona / artist / continuation fields below are all `None`
    // after construction; their meaning is defined by the remote API, not here.
    pub cover_clip_id: Option<String>,
    pub cover_start_s: Option<f64>,
    pub cover_end_s: Option<f64>,
    pub persona_id: Option<String>,
    pub artist_clip_id: Option<String>,
    pub artist_start_s: Option<f64>,
    pub artist_end_s: Option<f64>,
    pub continue_clip_id: Option<String>,
    pub continued_aligned_prompt: Option<String>,
    pub continue_at: Option<f64>,
    // Stem fields are omitted from the JSON entirely unless set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stem_type_id: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stem_type_group_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stem_task: Option<String>,
    /// Random v4 UUID generated per request by the constructors.
    pub transaction_uuid: String,
    /// `Some(1)` when a challenge token is set, otherwise `None`; maintained by
    /// `set_challenge_token`.
    pub token_provider: Option<u8>,
}
impl GenerateRequest {
    /// Creates a text-generation request for model `mv` in `create_mode`.
    ///
    /// Uses an empty `GenerationWebContext`, so the metadata user tier is `""`.
    pub fn new(mv: &str, create_mode: &str) -> Self {
        let empty_context = GenerationWebContext::default();
        Self::new_with_context(mv, create_mode, &empty_context)
    }

    /// Creates a text-generation request whose metadata draws on `context`.
    ///
    /// Every optional field starts unset, no challenge token is attached, and
    /// a fresh random `transaction_uuid` is minted for each call.
    pub fn new_with_context(mv: &str, create_mode: &str, context: &GenerationWebContext) -> Self {
        let metadata = GenerateMetadata::new_with_context(create_mode, context);
        let transaction_uuid = uuid::Uuid::new_v4().to_string();

        Self {
            generation_type: String::from(GENERATION_TYPE_TEXT),
            mv: mv.to_owned(),
            metadata,
            transaction_uuid,
            token: None,
            token_provider: None,
            task: None,
            title: None,
            tags: None,
            negative_tags: String::new(),
            prompt: String::new(),
            gpt_description_prompt: None,
            make_instrumental: false,
            user_uploaded_images_b64: None,
            override_fields: Vec::new(),
            cover_clip_id: None,
            cover_start_s: None,
            cover_end_s: None,
            persona_id: None,
            artist_clip_id: None,
            artist_start_s: None,
            artist_end_s: None,
            continue_clip_id: None,
            continued_aligned_prompt: None,
            continue_at: None,
            stem_type_id: None,
            stem_type_group_name: None,
            stem_task: None,
        }
    }

    /// Attaches, or with `None` removes, the challenge token.
    ///
    /// `token_provider` becomes `Some(CHALLENGE_TOKEN_PROVIDER)` exactly when a
    /// token is present and `None` otherwise.
    pub fn set_challenge_token(&mut self, token: Option<String>) {
        self.token_provider = token.is_some().then_some(CHALLENGE_TOKEN_PROVIDER);
        self.token = token;
    }
}
/// Web-client metadata nested under `metadata` in a `GenerateRequest`.
///
/// The fields marked `skip_serializing_if` are omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
pub struct GenerateMetadata {
    /// Set from `WEB_CLIENT_PATHNAME` (`"/create"`) by the constructor.
    pub web_client_pathname: String,
    pub is_max_mode: bool,
    pub is_mumble: bool,
    /// The `create_mode` string passed by the caller (e.g. `"custom"`).
    pub create_mode: String,
    /// Trimmed user tier from `GenerationWebContext`; empty when unset or blank.
    pub user_tier: String,
    /// Random v4 UUID minted each time metadata is constructed.
    pub create_session_token: String,
    pub disable_volume_normalization: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub control_sliders: Option<ControlSliders>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lyrics_model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_remix: Option<bool>,
}
impl GenerateMetadata {
    /// Builds metadata for `create_mode`, taking the user tier from `context`.
    ///
    /// Each call mints a new random `create_session_token`. Every boolean flag
    /// starts `false` and every optional tuning field starts unset.
    fn new_with_context(create_mode: &str, context: &GenerationWebContext) -> Self {
        let session_token = uuid::Uuid::new_v4();

        Self {
            create_mode: create_mode.to_owned(),
            user_tier: context.user_tier_value(),
            web_client_pathname: String::from(WEB_CLIENT_PATHNAME),
            create_session_token: session_token.to_string(),
            is_max_mode: false,
            is_mumble: false,
            disable_volume_normalization: false,
            control_sliders: None,
            lyrics_model: None,
            is_remix: None,
        }
    }
}
/// Optional slider values serialized under `metadata.control_sliders`.
///
/// Each slider is omitted from the JSON when `None`. The valid value ranges are
/// not visible here; presumably they are defined by the remote API — confirm
/// before documenting bounds.
#[derive(Debug, Serialize)]
pub struct ControlSliders {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub weirdness_constraint: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub style_weight: Option<f64>,
}
/// Response body returned by a generation call.
#[derive(Debug, Deserialize)]
pub struct GenerateResponse {
    /// Clips in the response. Becomes an empty `Vec` when the `clips` key is
    /// missing, rather than failing to deserialize.
    #[serde(default)]
    pub clips: Vec<Clip>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `request` to a JSON value, as it would be sent on the wire.
    fn to_json(request: GenerateRequest) -> serde_json::Value {
        serde_json::to_value(request).expect("request json")
    }

    #[test]
    fn generation_context_sets_shared_web_metadata() {
        let context = GenerationWebContext {
            user_tier: Some("tier-pro".into()),
        };
        let request = GenerateRequest::new_with_context("chirp-fenix", "custom", &context);
        let body = to_json(request);
        assert_eq!(body["generation_type"], "TEXT");
        assert_eq!(body["metadata"]["web_client_pathname"], "/create");
        assert_eq!(body["metadata"]["user_tier"], "tier-pro");
        assert!(body["metadata"]["create_session_token"].as_str().is_some());
        assert!(body["transaction_uuid"].as_str().is_some());
    }

    #[test]
    fn challenge_token_sets_web_token_provider() {
        let mut request = GenerateRequest::new("chirp-fenix", "custom");
        request.set_challenge_token(Some("challenge-token".into()));
        let body = to_json(request);
        assert_eq!(body["token"], "challenge-token");
        assert_eq!(body["token_provider"], 1);
    }

    #[test]
    fn clearing_challenge_token_also_clears_token_provider() {
        let mut request = GenerateRequest::new("chirp-fenix", "custom");
        request.set_challenge_token(Some("challenge-token".into()));
        request.set_challenge_token(None);
        let body = to_json(request);
        // Neither field has `skip_serializing_if`, so "unset" is serialized as `null`.
        assert!(body["token"].is_null());
        assert!(body["token_provider"].is_null());
    }

    #[test]
    fn user_tier_is_trimmed_and_blank_tiers_become_empty() {
        for (tier, expected) in [
            (None, ""),
            (Some("   "), ""),
            (Some("  tier-pro \n"), "tier-pro"),
        ] {
            let context = GenerationWebContext {
                user_tier: tier.map(String::from),
            };
            let body = to_json(GenerateRequest::new_with_context(
                "chirp-fenix",
                "custom",
                &context,
            ));
            assert_eq!(body["metadata"]["user_tier"], expected, "tier {:?}", tier);
        }
    }

    #[test]
    fn unset_optional_fields_are_omitted_or_null() {
        let body = to_json(GenerateRequest::new("chirp-fenix", "custom"));
        let fields = body.as_object().expect("request object");
        for key in [
            "task",
            "gpt_description_prompt",
            "stem_type_id",
            "stem_type_group_name",
            "stem_task",
        ] {
            assert!(!fields.contains_key(key), "`{}` should be omitted", key);
        }
        let metadata = body["metadata"].as_object().expect("metadata object");
        for key in ["control_sliders", "lyrics_model", "is_remix"] {
            assert!(!metadata.contains_key(key), "`metadata.{}` should be omitted", key);
        }
        // Options without `skip_serializing_if` are still sent, as explicit `null`.
        assert!(fields.contains_key("title"));
        assert!(body["title"].is_null());
        assert_eq!(body["mv"], "chirp-fenix");
        assert_eq!(body["metadata"]["create_mode"], "custom");
    }
}