use std::str::FromStr;
use kiromi_ai_memory::{AttributeValue, MemoryId, MemoryRef};
use rust_decimal::Decimal;
use crate::cli::{
AttrClearArgs, AttrFindArgs, AttrFindRangeArgs, AttrGetArgs, AttrListArgs, AttrSetArgs,
AttributeCmd, GlobalArgs,
};
use crate::cmd::get::probe_partition;
use crate::error::{CliError, ExitCode};
use crate::output;
use crate::runtime::Runtime;
/// Entry point for the `attribute` subcommand family.
///
/// Opens the [`Runtime`], dispatches `cmd` to its handler, and then calls
/// `rt.mem.close()` whether or not the handler succeeded.
///
/// # Errors
/// Returns the handler's error if it failed. Otherwise returns any error from
/// opening the runtime or from closing the memory store.
pub(crate) async fn run(cmd: AttributeCmd, globals: &GlobalArgs) -> Result<(), CliError> {
    let rt = Runtime::open(globals).await?;
    let result = match cmd {
        AttributeCmd::Set(a) => run_set(a, &rt, globals).await,
        AttributeCmd::Get(a) => run_get(a, &rt, globals).await,
        AttributeCmd::Clear(a) => run_clear(a, &rt, globals).await,
        AttributeCmd::List(a) => run_list(a, &rt, globals).await,
        AttributeCmd::Find(a) => run_find(a, &rt, globals).await,
        AttributeCmd::FindRange(a) => run_find_range(a, &rt, globals).await,
    };
    // Close unconditionally, but surface the handler's failure first: a close
    // error must not mask the error of the operation the user actually ran.
    let closed = rt.mem.close().await;
    result?;
    closed?;
    Ok(())
}
fn parse_value(kind: &str, raw: &str) -> Result<AttributeValue, CliError> {
match kind {
"string" => Ok(AttributeValue::String(raw.to_string())),
"int" => raw
.parse::<i64>()
.map(AttributeValue::Int)
.map_err(|e| CliError {
kind: ExitCode::Config,
source: anyhow::anyhow!("invalid int: {e}"),
}),
"decimal" => Decimal::from_str(raw)
.map(AttributeValue::Decimal)
.map_err(|e| CliError {
kind: ExitCode::Config,
source: anyhow::anyhow!("invalid decimal: {e}"),
}),
"bool" => match raw {
"true" | "1" => Ok(AttributeValue::Bool(true)),
"false" | "0" => Ok(AttributeValue::Bool(false)),
_ => Err(CliError {
kind: ExitCode::Config,
source: anyhow::anyhow!("invalid bool: {raw}"),
}),
},
"timestamp" => raw
.parse::<i64>()
.map(AttributeValue::Timestamp)
.map_err(|e| CliError {
kind: ExitCode::Config,
source: anyhow::anyhow!("invalid timestamp (millis): {e}"),
}),
"array" => serde_json::from_str::<AttributeValue>(raw).or_else(|_| {
serde_json::from_str::<Vec<AttributeValue>>(raw)
.map(AttributeValue::Array)
.map_err(|e| CliError {
kind: ExitCode::Config,
source: anyhow::anyhow!("invalid array (JSON expected): {e}"),
})
}),
other => Err(CliError {
kind: ExitCode::Config,
source: anyhow::anyhow!("unknown kind: {other}"),
}),
}
}
/// Resolves a textual memory id into the [`MemoryRef`] stored for that record.
///
/// The parsed id is paired with the partition from [`probe_partition`] and
/// looked up with `rt.mem.get`. The ref on the returned record is what the
/// caller receives.
///
/// # Errors
/// [`ExitCode::Config`] if `id` is not a valid [`MemoryId`]. Otherwise, any
/// error from `probe_partition` or from the store lookup.
async fn make_memory_ref(id: &str, rt: &Runtime) -> Result<MemoryRef, CliError> {
    let parsed = match MemoryId::from_str(id) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(CliError {
                kind: ExitCode::Config,
                source: anyhow::anyhow!("bad memory id: {e}"),
            });
        }
    };
    let partition = probe_partition()?;
    let lookup = MemoryRef {
        id: parsed,
        partition,
    };
    let record = rt.mem.get(&lookup).await?;
    Ok(record.r#ref)
}
/// Handles `attribute set`: parses the value as the requested kind and stores it
/// under `a.key` on the resolved memory.
///
/// Prints `{"set": true}` in JSON mode, otherwise `ok`.
async fn run_set(a: AttrSetArgs, rt: &Runtime, globals: &GlobalArgs) -> Result<(), CliError> {
    let target = make_memory_ref(&a.memory, rt).await?;
    let value = parse_value(&a.kind.kind, &a.value)?;
    rt.mem.set_attribute(&target, &a.key, value).await?;

    if !globals.json {
        println!("ok");
    } else {
        let body = serde_json::json!({"set": true});
        println!("{}", output::to_json(&body));
    }
    Ok(())
}
/// Handles `attribute get`: prints the value stored under `a.key` on the memory.
///
/// In JSON mode the optional value is serialized directly. Otherwise the value is
/// printed with `Display`, or `(absent)` when no attribute is set.
async fn run_get(a: AttrGetArgs, rt: &Runtime, globals: &GlobalArgs) -> Result<(), CliError> {
    let target = make_memory_ref(&a.memory, rt).await?;
    let found = rt.mem.get_attribute(&target, &a.key).await?;

    if !globals.json {
        match found {
            Some(value) => println!("{value}"),
            None => println!("(absent)"),
        }
        return Ok(());
    }

    println!("{}", output::to_json(&serde_json::json!(found)));
    Ok(())
}
/// Handles `attribute clear`: removes the attribute `a.key` from the memory.
///
/// Prints `{"cleared": true}` in JSON mode, otherwise `ok`.
async fn run_clear(a: AttrClearArgs, rt: &Runtime, globals: &GlobalArgs) -> Result<(), CliError> {
    let target = make_memory_ref(&a.memory, rt).await?;
    rt.mem.clear_attribute(&target, &a.key).await?;

    if !globals.json {
        println!("ok");
    } else {
        let body = serde_json::json!({"cleared": true});
        println!("{}", output::to_json(&body));
    }
    Ok(())
}
/// Handles `attribute list`: prints every attribute attached to the memory.
///
/// JSON mode serializes the whole collection. Plain mode prints one
/// tab-separated `key<TAB>kind<TAB>value` line per attribute.
async fn run_list(a: AttrListArgs, rt: &Runtime, globals: &GlobalArgs) -> Result<(), CliError> {
    let target = make_memory_ref(&a.memory, rt).await?;
    let attrs = rt.mem.attributes_of(&target).await?;

    if !globals.json {
        for (key, value) in &attrs {
            let kind = value.kind_str();
            println!("{key}\t{kind}\t{value}");
        }
    } else {
        println!("{}", output::to_json(&serde_json::json!(attrs)));
    }
    Ok(())
}
/// Handles `attribute find`: lists ids of memories whose attribute `a.key`
/// matches the given value, as looked up by `find_by_attribute`.
async fn run_find(a: AttrFindArgs, rt: &Runtime, globals: &GlobalArgs) -> Result<(), CliError> {
    let v = parse_value(&a.kind.kind, &a.value)?;
    let hits = rt.mem.find_by_attribute(&a.key, &v).await?;
    let ids: Vec<String> = hits.iter().map(|m| m.id.to_string()).collect();
    print_ids(&ids, globals.json);
    Ok(())
}

/// Handles `attribute find-range`: lists ids of memories whose attribute `a.key`
/// is reported by `find_by_attribute_range` for the bounds `a.min` and `a.max`.
///
/// Both bounds are parsed with the same kind `a.kind`.
async fn run_find_range(
    a: AttrFindRangeArgs,
    rt: &Runtime,
    globals: &GlobalArgs,
) -> Result<(), CliError> {
    let lo = parse_value(&a.kind, &a.min)?;
    let hi = parse_value(&a.kind, &a.max)?;
    let hits = rt.mem.find_by_attribute_range(&a.key, &lo, &hi).await?;
    let ids: Vec<String> = hits.iter().map(|m| m.id.to_string()).collect();
    print_ids(&ids, globals.json);
    Ok(())
}

/// Prints matched memory ids: one JSON array when `json` is set, otherwise one
/// id per line. Shared by the `find` and `find-range` handlers so their output
/// stays identical.
fn print_ids(ids: &[String], json: bool) {
    if json {
        println!("{}", output::to_json(&serde_json::json!(ids)));
    } else {
        for id in ids {
            println!("{id}");
        }
    }
}