use anyhow::{Context, Result};
use clap::Parser;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::client::{build_client, post_json};
use crate::config::canonical_base_url;
use crate::output::print_json;
// Arguments for the `batch` subcommand.
//
// Plain `//` comments are used deliberately: clap's derive would turn `///` doc
// comments into `--help` text.
#[derive(Parser, Debug)]
pub struct BatchCmd {
    // Positional argument: path to the YAML batch file to submit.
    pub file: std::path::PathBuf,
    // `--format <FORMAT>`; forwarded as the payload's `format` field, "verbose" when omitted.
    #[arg(long = "format", default_value = "verbose")]
    pub format: String,
}
/// Top-level document of a batch YAML file: a single `requests` list.
#[derive(Deserialize, Debug)]
struct BatchFile {
    /// Requests submitted together in one `/v1/batch` call.
    /// `BatchCmd::run` rejects an empty list and lists longer than 50 entries.
    requests: Vec<BatchRequest>,
}
/// One entry of a batch file.
///
/// Deserialized from YAML and then re-serialized unchanged into the `requests`
/// array of the JSON payload sent to `/v1/batch`.
#[derive(Serialize, Deserialize, Debug)]
struct BatchRequest {
    /// Name of the computation to run (the tests use `rsi` and `macd`).
    name: String,
    /// Instrument symbol, e.g. `BTC-USD`.
    symbol: String,
    /// Timeframe; filled in by `default_tf` ("5m") when the YAML omits it.
    #[serde(default = "default_tf")]
    tf: String,
    /// Optional exchange; `None` when omitted. It is still serialized (as `null`).
    #[serde(default)]
    exchange: Option<String>,
    /// Free-form per-request parameters; an empty map when omitted.
    #[serde(default)]
    params: HashMap<String, serde_json::Value>,
}
/// Serde default for a request's `tf` field: the timeframe used when it is omitted.
fn default_tf() -> String {
    String::from("5m")
}
impl BatchCmd {
    /// Largest number of requests accepted in one batch file. Enforced locally,
    /// before any network call is made.
    const MAX_REQUESTS: usize = 50;

    /// Loads the batch YAML file, validates it, and POSTs it to `{base}/v1/batch`,
    /// where `base` is `api_url` normalized by `canonical_base_url`. The JSON
    /// response is then printed.
    ///
    /// # Errors
    ///
    /// Returns an error when any of these happens:
    /// - the file cannot be read or is not valid batch YAML;
    /// - it contains zero requests or more than `MAX_REQUESTS` (50);
    /// - building the HTTP client fails;
    /// - the request itself fails.
    pub async fn run(self, api_url: &str, pretty: bool) -> Result<()> {
        let base = canonical_base_url(api_url);
        let batch = Self::load_batch(&self.file)?;
        Self::validate_request_count(batch.requests.len())?;

        let (client, token) = build_client()?;
        let payload = serde_json::json!({
            "requests": batch.requests,
            "format": self.format,
        });
        let url = format!("{base}/v1/batch");
        let rb = client.post(&url).bearer_auth(&token).json(&payload);
        let json = post_json(rb).await?;
        print_json(&json, pretty)
    }

    /// Reads the file at `path` and parses it as a [`BatchFile`].
    fn load_batch(path: &std::path::Path) -> Result<BatchFile> {
        // NOTE(review): this is a synchronous read inside an async call path. That is
        // fine for a small local file, but it blocks the executor thread while it runs.
        let content = std::fs::read_to_string(path)
            .with_context(|| format!("failed to read batch file: {}", path.display()))?;
        serde_yaml::from_str(&content).context("failed to parse batch YAML")
    }

    /// Rejects a batch that is empty or larger than `MAX_REQUESTS`.
    fn validate_request_count(count: usize) -> Result<()> {
        if count == 0 {
            anyhow::bail!("batch file contains no requests");
        }
        if count > Self::MAX_REQUESTS {
            // The message comes from the same constant the check uses, so the two cannot drift.
            anyhow::bail!(
                "batch file contains {} requests; maximum is {}",
                count,
                Self::MAX_REQUESTS
            );
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    #[test]
    fn parse_batch_file_arg() {
        let cmd = BatchCmd::try_parse_from(["batch", "signals.yaml"]).unwrap();
        assert_eq!(cmd.file.to_str().unwrap(), "signals.yaml");
        assert_eq!(cmd.format, "verbose");
    }

    #[test]
    fn parse_batch_compact() {
        let cmd =
            BatchCmd::try_parse_from(["batch", "signals.yaml", "--format", "compact"]).unwrap();
        assert_eq!(cmd.format, "compact");
    }

    #[test]
    fn deserialize_batch_yaml() {
        let yaml = r#"
requests:
  - name: rsi
    symbol: BTC-USD
    tf: 5m
    params:
      period: 14
  - name: macd
    symbol: ETH-USD
    tf: 1h
"#;
        let b: BatchFile = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(b.requests.len(), 2);
        assert_eq!(b.requests[0].name, "rsi");
        assert_eq!(b.requests[1].tf, "1h");
    }

    #[test]
    fn deserialize_batch_defaults_tf() {
        let yaml = r#"
requests:
  - name: rsi
    symbol: BTC-USD
"#;
        let b: BatchFile = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(b.requests[0].tf, "5m");
    }

    // The optional fields must fall back to their serde defaults when omitted.
    #[test]
    fn deserialize_batch_optional_fields_default() {
        let yaml = r#"
requests:
  - name: rsi
    symbol: BTC-USD
"#;
        let b: BatchFile = serde_yaml::from_str(yaml).unwrap();
        let req = &b.requests[0];
        assert!(req.exchange.is_none());
        assert!(req.params.is_empty());
    }

    // Explicit `exchange` and `params` values are carried through.
    #[test]
    fn deserialize_batch_exchange_and_params() {
        let yaml = r#"
requests:
  - name: rsi
    symbol: BTC-USD
    exchange: binance
    params:
      period: 14
"#;
        let b: BatchFile = serde_yaml::from_str(yaml).unwrap();
        let req = &b.requests[0];
        assert_eq!(req.exchange.as_deref(), Some("binance"));
        assert_eq!(req.params.get("period"), Some(&serde_json::json!(14)));
    }

    // The defaulted timeframe must appear in the JSON sent to the server.
    #[test]
    fn serialize_request_includes_default_tf() {
        let yaml = r#"
requests:
  - name: rsi
    symbol: BTC-USD
"#;
        let b: BatchFile = serde_yaml::from_str(yaml).unwrap();
        let value = serde_json::to_value(&b.requests[0]).unwrap();
        assert_eq!(value["tf"], "5m");
        assert_eq!(value["name"], "rsi");
        assert_eq!(value["symbol"], "BTC-USD");
    }
}