use std::thread;
use std::time::Duration;
use anyhow::{Context, Result, bail};
use atm_core::observability::{
AtmJsonNumber, AtmLogQuery, LogFieldKey, LogFieldMatch, LogFieldValue, LogLevelFilter, LogMode,
LogOrder, ObservabilityPort,
};
use atm_core::types::IsoTimestamp;
use clap::{Args, Subcommand, ValueEnum};
use crate::observability::CliObservability;
use crate::output;
/// Number of records requested by `atm log snapshot` / `atm log filter` when `--limit` is omitted.
const DEFAULT_SNAPSHOT_LIMIT: usize = 50;
/// Default delay, in milliseconds, between polls of an `atm log tail` follow session.
const DEFAULT_TAIL_POLL_INTERVAL_MS: u64 = 250;
// Arguments for the `atm log` command: a required subcommand selecting the
// snapshot, filter, or tail mode.
//
// Plain `//` comments are used on clap-derived items on purpose: `///` doc
// comments would be picked up by clap and change the `--help` output.
#[derive(Debug, Args)]
pub struct LogCommand {
    // The selected mode together with its mode-specific arguments.
    #[command(subcommand)]
    mode: LogModeCommand,
}
impl LogCommand {
    /// Executes the selected `atm log` mode against `observability`.
    ///
    /// `snapshot` and `filter` both run a single snapshot query and print the
    /// result. `filter` first insists that at least one filtering flag was
    /// given. `tail` hands control to the tail loop in `TailArgs::run`.
    ///
    /// # Errors
    ///
    /// Propagates argument-validation, query, and output errors.
    pub fn run(self, observability: &CliObservability) -> Result<()> {
        let query_args = match self.mode {
            LogModeCommand::Tail(tail_args) => return tail_args.run(observability),
            LogModeCommand::Snapshot(query_args) => query_args,
            LogModeCommand::Filter(query_args) => {
                // `filter` is a snapshot that refuses to run unfiltered.
                query_args.ensure_filter_present()?;
                query_args
            }
        };
        let query = query_args.build_query(LogMode::Snapshot)?;
        let snapshot = observability.query(query)?;
        output::print_log_snapshot(&snapshot, query_args.json)
    }
}
// The `atm log` subcommands. (`//` rather than `///` so clap's help text is
// left unchanged.)
#[derive(Debug, Subcommand)]
enum LogModeCommand {
    // One snapshot query using the shared query flags; the record count
    // defaults to DEFAULT_SNAPSHOT_LIMIT.
    Snapshot(QueryArgs),
    // Same snapshot query, but rejected unless at least one of --match,
    // --level, or --since is supplied.
    Filter(QueryArgs),
    // Opens a follow session and prints the records returned by each poll.
    Tail(TailArgs),
}
// Log levels accepted by `--level`. They are converted into `LogLevelFilter`
// for the query layer. clap's `ValueEnum` derives the accepted command-line
// spellings from these variant names.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum CliLogLevel {
    Trace,
    Debug,
    Info,
    Warn,
    Error,
}
impl From<CliLogLevel> for LogLevelFilter {
fn from(value: CliLogLevel) -> Self {
match value {
CliLogLevel::Trace => LogLevelFilter::Trace,
CliLogLevel::Debug => LogLevelFilter::Debug,
CliLogLevel::Info => LogLevelFilter::Info,
CliLogLevel::Warn => LogLevelFilter::Warn,
CliLogLevel::Error => LogLevelFilter::Error,
}
}
}
// Query flags shared by the snapshot, filter, and tail subcommands.
// (`//` rather than `///` so clap's help text is left unchanged.)
#[derive(Debug, Args)]
struct QueryArgs {
    // `--level`, repeatable; every occurrence adds one level to the query.
    #[arg(long = "level", value_enum)]
    levels: Vec<CliLogLevel>,
    // `--match KEY=VALUE`, repeatable; each is parsed by parse_match_expression.
    #[arg(long = "match", value_name = "KEY=VALUE")]
    matches: Vec<String>,
    // `--since`: an RFC3339 timestamp or a relative duration such as 15m
    // (see parse_since).
    #[arg(long)]
    since: Option<String>,
    // `--limit`: snapshot queries fall back to DEFAULT_SNAPSHOT_LIMIT when
    // omitted; tail queries pass the value through as-is (None when omitted).
    #[arg(long)]
    limit: Option<usize>,
    // `--json`: forwarded as the `json` flag to the output printers.
    #[arg(long)]
    json: bool,
}
impl QueryArgs {
fn build_query(&self, mode: LogMode) -> Result<AtmLogQuery> {
let limit = match mode {
LogMode::Snapshot => Some(self.limit.unwrap_or(DEFAULT_SNAPSHOT_LIMIT)),
LogMode::Tail => self.limit,
};
Ok(AtmLogQuery {
mode,
levels: self.levels.iter().copied().map(Into::into).collect(),
field_matches: self
.matches
.iter()
.map(|raw| parse_match_expression(raw))
.collect::<Result<Vec<_>>>()?,
since: self.since.as_deref().map(parse_since).transpose()?,
until: None,
limit,
order: LogOrder::NewestFirst,
})
}
fn ensure_filter_present(&self) -> Result<()> {
if self.matches.is_empty() && self.levels.is_empty() && self.since.is_none() {
bail!("atm log filter requires at least one of --match, --level, or --since");
}
Ok(())
}
}
// Arguments for `atm log tail`. (`//` rather than `///` so clap's help text
// is left unchanged.)
#[derive(Debug, Args)]
struct TailArgs {
    // The shared query flags, flattened so tail accepts the same options.
    #[command(flatten)]
    query: QueryArgs,
    // Milliseconds to sleep between polls of the follow session.
    #[arg(long, default_value_t = DEFAULT_TAIL_POLL_INTERVAL_MS)]
    poll_interval_ms: u64,
    // Test-only hidden flag: stop after this many polls so tests can leave
    // the otherwise endless tail loop.
    #[cfg(test)]
    #[arg(long, hide = true)]
    max_polls: Option<usize>,
}
impl TailArgs {
    /// Follows the log, printing the records returned by every poll.
    ///
    /// Production builds poll until an error occurs. Test builds additionally
    /// stop after `--max-polls` polls when that hidden flag is set. Both
    /// configurations share this single loop, so the code exercised by tests
    /// is the code that ships.
    ///
    /// # Errors
    ///
    /// Returns the first error from building the query, opening the follow
    /// session, polling, or printing records.
    fn run(self, observability: &CliObservability) -> Result<()> {
        // The poll cap only exists in test builds; production never stops on it.
        #[cfg(test)]
        let max_polls = self.max_polls;
        #[cfg(not(test))]
        let max_polls: Option<usize> = None;

        let poll_interval = Duration::from_millis(self.poll_interval_ms);
        let mut session = observability.follow(self.query.build_query(LogMode::Tail)?)?;
        let mut polls = 0usize;
        loop {
            let snapshot = session.poll()?;
            output::print_log_records(snapshot.records, self.query.json)?;
            // Saturate so a very long-running tail can't overflow-panic in a
            // debug build.
            polls = polls.saturating_add(1);
            if max_polls.is_some_and(|limit| polls >= limit) {
                return Ok(());
            }
            thread::sleep(poll_interval);
        }
    }
}
/// Parses a `--match KEY=VALUE` expression into a field match.
///
/// The expression is split at the first `=`, so values may themselves contain
/// `=`. Surrounding whitespace on the key is ignored. The value is kept
/// verbatim and typed by `parse_match_value`.
///
/// # Errors
///
/// Fails when there is no `=`, or when the key is empty or whitespace-only.
/// Also fails when `LogFieldKey::new` rejects the key.
fn parse_match_expression(raw: &str) -> Result<LogFieldMatch> {
    let (key, value) = raw
        .split_once('=')
        .ok_or_else(|| anyhow::anyhow!("invalid --match expression '{raw}'; expected key=value"))?;
    // Use the trimmed key for construction as well as validation. Otherwise
    // `--match " level=warn"` would target a key with a leading space instead
    // of `level`.
    let key = key.trim();
    if key.is_empty() {
        bail!("invalid --match expression '{raw}'; key must not be empty");
    }
    Ok(LogFieldMatch {
        key: LogFieldKey::new(key.to_string())?,
        value: parse_match_value(value),
    })
}
/// Infers a typed field value from the right-hand side of a `--match`.
///
/// Precedence:
/// - `true` / `false` become booleans and `null` becomes null, all compared
///   ASCII case-insensitively.
/// - Anything `AtmJsonNumber` accepts becomes a number.
/// - Everything else is kept as a string.
fn parse_match_value(raw: &str) -> LogFieldValue {
    match raw {
        keyword if keyword.eq_ignore_ascii_case("true") => LogFieldValue::bool(true),
        keyword if keyword.eq_ignore_ascii_case("false") => LogFieldValue::bool(false),
        keyword if keyword.eq_ignore_ascii_case("null") => LogFieldValue::null(),
        literal => match AtmJsonNumber::new(literal.to_string()) {
            Ok(number) => LogFieldValue::number(number),
            Err(_) => LogFieldValue::string(literal.to_string()),
        },
    }
}
/// Parses a `--since` value: an RFC3339 timestamp or a relative duration.
///
/// RFC3339 is tried first. If that fails, the value is parsed as a relative
/// duration such as `30s` or `7d`, counted back from now.
///
/// # Errors
///
/// Fails when neither form matches. The top-level message names both accepted
/// forms, because the relative-duration error alone is misleading for what
/// was meant to be a timestamp (e.g. `2024-01-01`). The underlying
/// relative-duration error stays in the error chain.
fn parse_since(raw: &str) -> Result<IsoTimestamp> {
    parse_rfc3339(raw)
        .or_else(|_| parse_relative_duration(raw))
        .with_context(|| {
            format!(
                "invalid --since value '{raw}'; expected an RFC3339 timestamp \
                 (e.g. 2024-01-01T00:00:00Z) or a relative duration (e.g. 30s, 15m, 2h, 7d)"
            )
        })
}
/// Parses an RFC3339 timestamp with any offset and normalizes it to UTC.
///
/// # Errors
///
/// Fails with "invalid RFC3339 timestamp: ..." when chrono rejects the input.
fn parse_rfc3339(raw: &str) -> Result<IsoTimestamp> {
    let parsed = chrono::DateTime::parse_from_rfc3339(raw)
        .with_context(|| format!("invalid RFC3339 timestamp: {raw}"))?;
    let in_utc = parsed.with_timezone(&chrono::Utc);
    Ok(in_utc.into())
}
fn parse_relative_duration(raw: &str) -> Result<IsoTimestamp> {
if raw.len() < 2 {
bail!("invalid relative duration '{raw}'; expected forms like 30s, 15m, 2h, or 7d");
}
let (amount, unit) = raw
.char_indices()
.next_back()
.map(|(index, _)| (&raw[..index], &raw[index..]))
.ok_or_else(|| {
anyhow::anyhow!(
"invalid relative duration '{raw}'; expected forms like 30s, 15m, 2h, or 7d"
)
})?;
let amount: i64 = amount.parse().with_context(|| {
format!("invalid relative duration '{raw}'; duration amount must be an integer")
})?;
let delta = match unit {
"s" => chrono::Duration::seconds(amount),
"m" => chrono::Duration::minutes(amount),
"h" => chrono::Duration::hours(amount),
"d" => chrono::Duration::days(amount),
_ => bail!("invalid relative duration '{raw}'; supported units are s, m, h, d"),
};
Ok((chrono::Utc::now() - delta).into())
}
#[cfg(test)]
mod tests {
    use super::parse_relative_duration;

    #[test]
    fn parse_relative_duration_rejects_multibyte_suffix_without_panicking() {
        // `\u{b5}` is `µ` (two bytes in UTF-8), the suffix users reach for when
        // typing microseconds. It is written as an escape so the literal can't
        // be re-encoded into mojibake again.
        let error = parse_relative_duration("10\u{b5}").expect_err("invalid unit");
        assert!(
            error.to_string().contains("supported units are s, m, h, d"),
            "error: {error}"
        );
    }

    #[test]
    fn parse_relative_duration_accepts_supported_units() {
        for raw in ["30s", "15m", "2h", "7d"] {
            if let Err(error) = parse_relative_duration(raw) {
                panic!("{raw} should parse: {error}");
            }
        }
    }

    #[test]
    fn parse_relative_duration_rejects_missing_amount() {
        assert!(parse_relative_duration("").is_err());
        assert!(parse_relative_duration("s").is_err());
    }
}