use serde::Deserialize;
use crate::step::{StepAction, TaskStep};
use super::DecompositionError;
/// One step of a task decomposition, as deserialized from JSON by `parse_steps`.
///
/// Every field carries `#[serde(default)]`, so any of them may be absent from the input, and
/// each is deserialized through a lenient helper so that loosely-typed JSON still parses.
/// `build_task_step` converts this raw form into a `TaskStep`.
#[derive(Debug, Deserialize)]
pub(super) struct RawStep {
    /// Human-readable description of the step; also used as the fallback payload for several
    /// action kinds. Scalars are coerced to text; `null`, `""` or an absent field give `""`.
    #[serde(default, deserialize_with = "lenient_required_string")]
    pub(super) description: String,
    /// Selects the `StepAction` variant (e.g. `"research"`, `"implement"`, `"shell"`).
    /// Coerced the same way as `description`.
    #[serde(default, deserialize_with = "lenient_required_string")]
    pub(super) action_type: String,
    /// Command line for `execute`, `test` and `shell` actions.
    /// Scalars are coerced to text; `null` or `""` give `None`.
    #[serde(default, deserialize_with = "lenient_optional_string")]
    pub(super) command: Option<String>,
    /// Query for `research` actions.
    #[serde(default, deserialize_with = "lenient_optional_string")]
    pub(super) query: Option<String>,
    /// Spec for `implement` actions, or the output text for `plan` actions.
    #[serde(default, deserialize_with = "lenient_optional_string")]
    pub(super) spec: Option<String>,
    /// Agent name for `implement` actions.
    #[serde(default, deserialize_with = "lenient_optional_string")]
    pub(super) agent: Option<String>,
    /// Artifact to review for `review` actions.
    #[serde(default, deserialize_with = "lenient_optional_string")]
    pub(super) artifact: Option<String>,
    /// Channel for `notify` actions.
    #[serde(default, deserialize_with = "lenient_optional_string")]
    pub(super) channel: Option<String>,
    /// Message text for `notify` actions.
    #[serde(default, deserialize_with = "lenient_optional_string")]
    pub(super) message: Option<String>,
    /// Zero-based positions (into the plan's step list) of the steps this one depends on.
    /// See `lenient_usize_vec` for the accepted input forms.
    #[serde(default, deserialize_with = "lenient_usize_vec")]
    pub(super) depends_on: Vec<usize>,
    /// Requested audit tier. `build_task_step` recognizes `"read"`, `"write"`,
    /// `"destructive"` and `"external"`; anything else (or `None`) is treated as `Execute`.
    #[serde(default, deserialize_with = "lenient_optional_string")]
    pub(super) tier: Option<String>,
    /// Estimated token cost of the step; `null` or an absent field give `None`.
    /// NOTE: unlike the string fields this is not coerced — a non-integer value such as
    /// `"500"` makes deserialization fail.
    #[serde(default, deserialize_with = "null_to_default")]
    pub(super) estimated_tokens: Option<u64>,
}
/// Deserializes a `T`, substituting `T::default()` when the input is `null`.
///
/// Any non-null input must deserialize as `T` itself; errors are propagated unchanged.
fn null_to_default<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
    T: Default + Deserialize<'de>,
    D: serde::Deserializer<'de>,
{
    Option::<T>::deserialize(deserializer).map(Option::unwrap_or_default)
}
fn lenient_optional_string<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::{self, Visitor};
use std::fmt;
struct V;
impl<'de> Visitor<'de> for V {
type Value = Option<String>;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("string, integer, float, bool, or null")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
Ok(if v.is_empty() {
None
} else {
Some(v.to_string())
})
}
fn visit_string<E: de::Error>(self, v: String) -> Result<Self::Value, E> {
Ok(if v.is_empty() { None } else { Some(v) })
}
fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(None)
}
fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(None)
}
fn visit_some<D: serde::Deserializer<'de>>(self, d: D) -> Result<Self::Value, D::Error> {
d.deserialize_any(self)
}
fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
Ok(Some(v.to_string()))
}
fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
Ok(Some(v.to_string()))
}
fn visit_f64<E: de::Error>(self, v: f64) -> Result<Self::Value, E> {
Ok(Some(v.to_string()))
}
fn visit_bool<E: de::Error>(self, v: bool) -> Result<Self::Value, E> {
Ok(Some(v.to_string()))
}
}
deserializer.deserialize_any(V)
}
/// Like `lenient_optional_string`, but yields an empty `String` instead of `None`.
///
/// Accepts the same scalar inputs; `null` and `""` both become `""`.
fn lenient_required_string<'de, D>(deserializer: D) -> Result<String, D::Error>
where
    D: serde::Deserializer<'de>,
{
    lenient_optional_string(deserializer).map(Option::unwrap_or_default)
}
fn lenient_usize_vec<'de, D>(deserializer: D) -> Result<Vec<usize>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::{self, SeqAccess, Visitor};
use std::fmt;
struct V;
impl<'de> Visitor<'de> for V {
type Value = Vec<usize>;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("array of indices, single index, or null")
}
fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(Vec::new())
}
fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(Vec::new())
}
fn visit_some<D: serde::Deserializer<'de>>(self, d: D) -> Result<Self::Value, D::Error> {
d.deserialize_any(self)
}
fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
Ok(vec![v as usize])
}
fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
if v < 0 {
Ok(Vec::new())
} else {
Ok(vec![v as usize])
}
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let mut out = Vec::new();
while let Some(elem) = seq.next_element::<serde_json::Value>()? {
if let Some(n) = elem.as_u64() {
out.push(n as usize);
} else if let Some(n) = elem.as_i64() {
if n >= 0 {
out.push(n as usize);
}
}
}
Ok(out)
}
}
deserializer.deserialize_any(V)
}
/// Converts a deserialized [`RawStep`] into a [`TaskStep`].
///
/// * `i` — this step's position in the plan; its id is `ids[i]`.
/// * `raw` — the step as parsed from the decomposition output.
/// * `ids` — the id assigned to every step, indexed by position. The entries of
///   `raw.depends_on` are positions into this slice.
///
/// Dependency positions are dropped if they are out of range, point at this step itself, or
/// repeat an earlier entry.
///
/// `raw.action_type` selects the [`StepAction`] variant. An unrecognized value becomes a
/// `Plan` whose output is the description.
///
/// `raw.tier` maps `"read"`, `"write"`, `"destructive"` and `"external"` to the matching
/// `audit::ActionTier`. Anything else, including `None`, becomes `Execute`. A `notify` step
/// tagged `external` is lowered to `Read`.
///
/// # Panics
///
/// Panics if `i >= ids.len()`.
pub(super) fn build_task_step(i: usize, raw: RawStep, ids: &[String]) -> TaskStep {
    // Resolve dependency positions to step ids. A self-reference would make the step wait on
    // itself, so it is discarded along with out-of-range positions and duplicates.
    let mut depends_on: Vec<String> = Vec::with_capacity(raw.depends_on.len());
    for id in raw
        .depends_on
        .iter()
        .filter(|&&idx| idx != i)
        .filter_map(|&idx| ids.get(idx))
    {
        if !depends_on.contains(id) {
            depends_on.push(id.clone());
        }
    }

    let action = match raw.action_type.as_str() {
        "research" => StepAction::Research {
            query: raw.query.unwrap_or_else(|| raw.description.clone()),
        },
        "plan" => StepAction::Plan {
            output: raw.spec.unwrap_or_default(),
        },
        "implement" => StepAction::Implement {
            spec: raw.spec.unwrap_or_else(|| raw.description.clone()),
            agent: raw.agent.unwrap_or_else(|| "default".to_string()),
        },
        // Command-running actions use the process's current directory; if it cannot be
        // read, `unwrap_or_default` yields an empty path.
        "execute" => StepAction::Execute {
            command: raw.command.unwrap_or_default(),
            workdir: std::env::current_dir().unwrap_or_default(),
        },
        "test" => StepAction::Test {
            command: raw.command.unwrap_or_else(|| "cargo test".to_string()),
            workdir: std::env::current_dir().unwrap_or_default(),
        },
        "shell" => StepAction::Shell {
            command: raw.command.unwrap_or_default(),
            workdir: std::env::current_dir().unwrap_or_default(),
        },
        "review" => StepAction::Review {
            artifact: raw.artifact.unwrap_or_else(|| raw.description.clone()),
        },
        "notify" => StepAction::Notify {
            channel: raw.channel.unwrap_or_else(|| "default".to_string()),
            message: raw.message.unwrap_or_else(|| raw.description.clone()),
        },
        // Unknown action kinds degrade to a plan step carrying the description.
        _ => StepAction::Plan {
            output: raw.description.clone(),
        },
    };

    let tier = match raw.tier.as_deref() {
        Some("read") => audit::ActionTier::Read,
        Some("write") => audit::ActionTier::Write,
        Some("destructive") => audit::ActionTier::Destructive,
        Some("external") => audit::ActionTier::External,
        _ => audit::ActionTier::Execute,
    };
    // NOTE(review): the reason for lowering notify/external to Read is not visible here —
    // presumably notifications are not meant to be gated as external actions; confirm.
    let tier = match (&action, tier) {
        (StepAction::Notify { .. }, audit::ActionTier::External) => audit::ActionTier::Read,
        (_, t) => t,
    };

    TaskStep {
        id: ids[i].clone(),
        description: raw.description,
        action,
        depends_on,
        tier,
        estimated_tokens: raw.estimated_tokens.unwrap_or(0),
    }
}
/// Parses raw decomposition text into a list of [`RawStep`]s.
///
/// The JSON array may be surrounded by other content, such as prose, code fences or an
/// enclosing object. Everything from the first `[` to the last `]` is parsed as the array.
///
/// If the text has no such pair, the whole trimmed text is parsed instead. That covers three
/// cases: no `[`, no `]`, or the last `]` coming before the first `[`.
///
/// # Errors
///
/// Returns [`DecompositionError::Parse`], carrying the `serde_json` error message, when the
/// selected text does not deserialize as a JSON array of steps.
pub(super) fn parse_steps(raw: &str) -> Result<Vec<RawStep>, DecompositionError> {
    let trimmed = raw.trim();
    // Both delimiters are ASCII, so the byte offsets from find/rfind are char boundaries.
    // Requiring `start < end` matters: in text such as "a] b [c" the last ']' precedes the
    // first '[', and slicing `start..=end` would panic.
    let json_str = match (trimmed.find('['), trimmed.rfind(']')) {
        (Some(start), Some(end)) if start < end => &trimmed[start..=end],
        _ => trimmed,
    };
    serde_json::from_str(json_str).map_err(|e| DecompositionError::Parse(e.to_string()))
}