use anyhow::{Context, Result, bail};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// One named axis of a [`TensorSchema`], as declared in a schema file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDimension {
/// Identifier used for lookups (`TensorSchema::dimension_by_id`), `name=value`
/// input parsing, and the rune legend.
pub id: String,
/// Human-readable label shown in descriptions and interactive prompts.
pub name: String,
/// Optional short prefix written before this dimension's value by
/// `StateTensor::encode_with_runes`; absent when omitted from the schema file.
#[serde(default)]
pub rune: Option<String>,
/// Text describing what low / mid / high values of this dimension mean.
pub anchors: DimensionAnchors,
/// Value used by `StateTensor::default_from_schema` and for blank answers in
/// `guided_capture`; 0.5 when omitted from the schema file.
#[serde(default = "default_half")]
pub default: f32,
}
/// Serde default for `TensorDimension::default`: the midpoint of the `[0.0, 1.0]` range.
fn default_half() -> f32 {
    0.5_f32
}
/// Textual anchors describing the low, middle and high ends of a dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DimensionAnchors {
/// Description of values near 0.0.
pub low: String,
/// Optional description of mid-range values; when absent, the code in this file
/// derives mid-range text from `low` / `high` instead.
#[serde(default)]
pub mid: Option<String>,
/// Description of values near 1.0.
pub high: String,
}
/// A named reference point in tensor space, matched by `StateTensor::nearest_mood`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMood {
/// Human-readable explanation of the mood.
pub description: String,
/// Reference values, compared position-by-position against a state's values.
pub tensor: Vec<f32>,
/// Optional per-dimension weights for the distance; missing entries count as 1.0.
#[serde(default)]
pub weights: Option<Vec<f32>>,
/// Maximum weighted Euclidean distance at which a state still matches this mood;
/// 0.3 when omitted from the schema file.
#[serde(default = "default_tolerance")]
pub tolerance: f32,
}
/// Serde default for `TensorMood::tolerance`.
fn default_tolerance() -> f32 {
    0.3_f32
}
/// A state-tensor schema: an ordered list of dimensions plus optional named moods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSchema {
/// Schema identifier; `load_by_id` looks files up by this id and encoded
/// tensors carry it after `@state:`.
pub id: String,
/// Human-readable schema name.
pub name: String,
/// Version number; 1 when omitted from the schema file.
#[serde(default = "default_version")]
pub version: u32,
/// Dimensions in order; a tensor's `values[i]` corresponds to `dimensions[i]`.
pub dimensions: Vec<TensorDimension>,
/// Named moods keyed by mood name; empty when omitted from the schema file.
#[serde(default)]
pub moods: HashMap<String, TensorMood>,
}
/// Serde default for `TensorSchema::version`.
fn default_version() -> u32 {
    1_u32
}
impl TensorSchema {
pub fn load(path: &Path) -> Result<Self> {
let content = fs::read_to_string(path)
.with_context(|| format!("Failed to read schema file: {:?}", path))?;
serde_yaml::from_str(&content)
.or_else(|_| serde_json::from_str(&content))
.with_context(|| format!("Failed to parse schema: {:?}", path))
}
pub fn load_by_id(schema_id: &str) -> Result<Self> {
if let Ok(schema_path) = std::env::var("MX_STATE_SCHEMA") {
let path = std::path::PathBuf::from(&schema_path);
if path.exists() {
let schema = Self::load(&path)?;
if schema.id == schema_id {
return Ok(schema);
}
}
}
let schemas_dir = crate::paths::schemas_dir();
let yaml_path = schemas_dir.join(format!("{}.yaml", schema_id));
if yaml_path.exists() {
return Self::load(&yaml_path);
}
let json_path = schemas_dir.join(format!("{}.json", schema_id));
if json_path.exists() {
return Self::load(&json_path);
}
bail!(
"Schema '{}' not found in {}",
schema_id,
schemas_dir.display()
)
}
pub fn load_default() -> Result<Self> {
if let Ok(schema_path) = std::env::var("MX_STATE_SCHEMA") {
let path = std::path::PathBuf::from(&schema_path);
if path.exists() {
return Self::load(&path);
}
}
Self::load_by_id("crewu")
}
pub fn list_available() -> Result<Vec<String>> {
let mut schemas = Vec::new();
let schemas_dir = crate::paths::schemas_dir();
if schemas_dir.exists() {
for entry in fs::read_dir(&schemas_dir)? {
let entry = entry?;
let path = entry.path();
if let Some(ext) = path.extension()
&& (ext == "yaml" || ext == "yml" || ext == "json")
&& let Some(stem) = path.file_stem()
{
schemas.push(stem.to_string_lossy().to_string());
}
}
}
schemas.sort();
schemas.dedup();
Ok(schemas)
}
pub fn dimension(&self, index: usize) -> Option<&TensorDimension> {
self.dimensions.get(index)
}
pub fn dimension_by_id(&self, id: &str) -> Option<(usize, &TensorDimension)> {
self.dimensions.iter().enumerate().find(|(_, d)| d.id == id)
}
pub fn dimension_count(&self) -> usize {
self.dimensions.len()
}
}
/// A concrete state: one value per dimension of the schema named by `schema_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateTensor {
/// Id of the `TensorSchema` these values belong to.
pub schema_id: String,
/// Values in schema dimension order; the parsers and `set` clamp them to
/// `[0.0, 1.0]`, while `new` stores them as given.
pub values: Vec<f32>,
}
impl StateTensor {
    /// Values below this are described with the dimension's `low` anchor.
    const LOW_ANCHOR_MAX: f32 = 0.33;
    /// Values above this are described with the dimension's `high` anchor.
    const HIGH_ANCHOR_MIN: f32 = 0.66;

    /// Creates a tensor from raw parts. Values are stored as given (not clamped).
    pub fn new(schema_id: String, values: Vec<f32>) -> Self {
        Self { schema_id, values }
    }

    /// Builds a tensor holding each dimension's `default`, in schema order.
    pub fn default_from_schema(schema: &TensorSchema) -> Self {
        Self {
            schema_id: schema.id.clone(),
            values: schema.dimensions.iter().map(|d| d.default).collect(),
        }
    }

    /// Parses `|`-separated values, one per schema dimension in order
    /// (e.g. `"0.3|0.7"`). Each value is trimmed and clamped into `[0.0, 1.0]`.
    ///
    /// # Errors
    /// Fails when the value count differs from the schema's dimension count, or
    /// when a value is not a number (NaN is rejected).
    pub fn parse_values(schema: &TensorSchema, input: &str) -> Result<Self> {
        let parts: Vec<&str> = input.split('|').collect();
        if parts.len() != schema.dimension_count() {
            bail!(
                "Expected {} values for schema '{}', got {}",
                schema.dimension_count(),
                schema.id,
                parts.len()
            );
        }
        let values = schema
            .dimensions
            .iter()
            .zip(&parts)
            .map(|(dim, part)| {
                Self::parse_unit_value(part).with_context(|| {
                    format!("Invalid value for dimension '{}': {}", dim.id, part)
                })
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            schema_id: schema.id.clone(),
            values,
        })
    }

    /// Parses whitespace-separated `name=value` pairs (e.g. `"dim1=0.3 dim2=0.7"`).
    ///
    /// Names are case-insensitive. A name may be an abbreviation (prefix) of a
    /// dimension id. When several supplied names are prefixes of the same id, the
    /// longest one wins, so the result does not depend on hash-map iteration order.
    /// Values are clamped into `[0.0, 1.0]`.
    ///
    /// # Errors
    /// Fails on a malformed pair (missing `=`, more than one `=`, or an empty name),
    /// on a non-numeric or NaN value, or when some dimension receives no value.
    pub fn parse_named_dimensions(schema: &TensorSchema, input: &str) -> Result<Self> {
        let mut named_values: HashMap<String, f32> = HashMap::new();
        for part in input.split_whitespace() {
            // An empty name would prefix-match (and silently fill) every dimension.
            let Some((raw_name, raw_value)) = part
                .split_once('=')
                .filter(|(name, value)| !name.trim().is_empty() && !value.contains('='))
            else {
                bail!("Invalid dimension format '{}'. Expected 'name=value'", part);
            };
            let name = raw_name.trim().to_lowercase();
            let value = Self::parse_unit_value(raw_value).with_context(|| {
                format!("Invalid value for dimension '{}': {}", name, raw_value)
            })?;
            named_values.insert(name, value);
        }

        let mut values = Vec::with_capacity(schema.dimensions.len());
        for dim in &schema.dimensions {
            let dim_id = dim.id.to_lowercase();
            let value = named_values.get(&dim_id).copied().or_else(|| {
                named_values
                    .iter()
                    .filter(|(name, _)| dim_id.starts_with(name.as_str()))
                    .max_by_key(|(name, _)| name.len())
                    .map(|(_, &v)| v)
            });
            let Some(value) = value else {
                bail!(
                    "No value provided for dimension '{}'. Available: {}",
                    dim.id,
                    schema
                        .dimensions
                        .iter()
                        .map(|d| d.id.as_str())
                        .collect::<Vec<_>>()
                        .join(", ")
                );
            };
            values.push(value);
        }
        Ok(Self {
            schema_id: schema.id.clone(),
            values,
        })
    }

    /// Encodes as `@state:<schema_id>|v1|v2|...` with two decimals per value.
    ///
    /// A tensor with no values encodes as `@state:<schema_id>` (no trailing `|`),
    /// so it round-trips through `decode`.
    pub fn encode(&self) -> String {
        use std::fmt::Write;
        let mut out = format!("@state:{}", self.schema_id);
        for value in &self.values {
            // Writing into a String cannot fail.
            let _ = write!(out, "|{:.2}", value);
        }
        out
    }

    /// Like `encode`, but prefixes each value with its dimension's rune (if any).
    pub fn encode_with_runes(&self, schema: &TensorSchema) -> String {
        use std::fmt::Write;
        let mut out = format!("@state:{}", self.schema_id);
        for (i, value) in self.values.iter().enumerate() {
            let rune = schema
                .dimensions
                .get(i)
                .and_then(|d| d.rune.as_deref())
                .unwrap_or("");
            // Writing into a String cannot fail.
            let _ = write!(out, "|{}{:.2}", rune, value);
        }
        out
    }

    /// Decodes the output of `encode` / `encode_with_runes`.
    ///
    /// Leading non-numeric characters of each value (a rune prefix) are skipped.
    /// Values are clamped into `[0.0, 1.0]`.
    ///
    /// # Errors
    /// Fails if the input lacks the `@state:` prefix, has an empty schema id, or
    /// contains a value that is not a number (NaN is rejected).
    pub fn decode(input: &str) -> Result<Self> {
        let Some(rest) = input.trim().strip_prefix("@state:") else {
            bail!("Invalid tensor format: must start with @state:");
        };
        let mut parts = rest.split('|');
        // `split` always yields at least one item, so emptiness must be checked on the id itself.
        let schema_id = parts.next().unwrap_or("");
        if schema_id.trim().is_empty() {
            bail!("Invalid tensor format: missing schema ID");
        }
        let values = parts
            .map(|part| {
                let numeric = part
                    .trim_start_matches(|c: char| !c.is_ascii_digit() && c != '.' && c != '-');
                Self::parse_unit_value(numeric).with_context(|| format!("Invalid value: {}", part))
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            schema_id: schema_id.to_string(),
            values,
        })
    }

    /// Weighted Euclidean distance `sqrt(sum(w_i * (a_i - b_i)^2))` to a mood.
    ///
    /// Missing weights count as 1.0. Returns `f32::MAX` when the value count
    /// differs from the mood's tensor length.
    pub fn distance_to_mood(&self, mood: &TensorMood) -> f32 {
        if self.values.len() != mood.tensor.len() {
            return f32::MAX;
        }
        let weights = mood.weights.as_deref().unwrap_or(&[]);
        self.values
            .iter()
            .zip(&mood.tensor)
            .enumerate()
            .map(|(i, (a, b))| weights.get(i).copied().unwrap_or(1.0) * (a - b).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Returns the closest mood whose distance is within its own tolerance, as
    /// `(name, mood, distance)`.
    ///
    /// Equal distances are broken by mood name, so the result is deterministic.
    pub fn nearest_mood<'a>(
        &self,
        schema: &'a TensorSchema,
    ) -> Option<(&'a str, &'a TensorMood, f32)> {
        schema
            .moods
            .iter()
            .map(|(name, mood)| (name.as_str(), mood, self.distance_to_mood(mood)))
            .filter(|(_, mood, distance)| *distance <= mood.tolerance)
            .min_by(|a, b| a.2.total_cmp(&b.2).then_with(|| a.0.cmp(b.0)))
    }

    /// One-line description: `"<name>: <value> (<anchor text>)"` per dimension.
    ///
    /// The anchor text comes from `interpolate_anchor_description`, so mid-range
    /// values without a `mid` anchor read "moderately <low|high>" instead of being
    /// labelled with the low anchor.
    pub fn describe(&self, schema: &TensorSchema) -> String {
        self.values
            .iter()
            .zip(&schema.dimensions)
            .map(|(value, dim)| {
                format!(
                    "{}: {:.2} ({})",
                    dim.name,
                    value,
                    self.interpolate_anchor_description(dim, *value)
                )
            })
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Multi-line wake-state text: the rune-encoded tensor, an optional rune
    /// legend, a blank line, then one anchor phrase per dimension.
    ///
    /// # Errors
    /// Propagates `fmt::Error` from writing into the output string.
    pub fn format_bootstrap(&self, schema: &TensorSchema) -> Result<String> {
        use std::fmt::Write;
        let mut output = String::new();
        writeln!(output, "Wake State: {}", self.encode_with_runes(schema))?;
        let legend: Vec<String> = schema
            .dimensions
            .iter()
            .filter_map(|dim| dim.rune.as_ref().map(|rune| format!("{}={}", rune, dim.id)))
            .collect();
        if !legend.is_empty() {
            writeln!(output, "({})", legend.join(", "))?;
        }
        writeln!(output)?;
        let descriptions: Vec<String> = self
            .values
            .iter()
            .zip(&schema.dimensions)
            .map(|(value, dim)| {
                format!(
                    "{} ({:.1})",
                    self.interpolate_anchor_description(dim, *value),
                    value
                )
            })
            .collect();
        write!(output, "{}.", descriptions.join(", "))?;
        Ok(output)
    }

    /// Picks anchor text for `value`.
    ///
    /// - Below 0.33 the `low` anchor is used; above 0.66 the `high` anchor.
    /// - In between, the `mid` anchor is used when present.
    /// - Without a `mid` anchor, the text is "moderately <high>" above 0.5,
    ///   otherwise "moderately <low>".
    fn interpolate_anchor_description(&self, dim: &TensorDimension, value: f32) -> String {
        let anchors = &dim.anchors;
        if value < Self::LOW_ANCHOR_MAX {
            anchors.low.clone()
        } else if value > Self::HIGH_ANCHOR_MIN {
            anchors.high.clone()
        } else if let Some(mid) = &anchors.mid {
            mid.clone()
        } else if value > 0.5 {
            format!("moderately {}", anchors.high)
        } else {
            format!("moderately {}", anchors.low)
        }
    }

    /// Value of the dimension with id `dim_id`, if the id exists and the tensor
    /// has a value at that position.
    pub fn get(&self, schema: &TensorSchema, dim_id: &str) -> Option<f32> {
        let (idx, _) = schema.dimension_by_id(dim_id)?;
        self.values.get(idx).copied()
    }

    /// Sets the dimension with id `dim_id`, clamping `value` into `[0.0, 1.0]`.
    ///
    /// # Errors
    /// Fails for an unknown dimension id, a NaN value, or when this tensor has
    /// fewer values than the dimension's position. Previously this last case was
    /// silently ignored.
    pub fn set(&mut self, schema: &TensorSchema, dim_id: &str, value: f32) -> Result<()> {
        let (idx, _) = schema
            .dimension_by_id(dim_id)
            .ok_or_else(|| anyhow::anyhow!("Unknown dimension: {}", dim_id))?;
        let value = Self::clamp_unit(value)
            .with_context(|| format!("Invalid value for dimension '{}': NaN", dim_id))?;
        let len = self.values.len();
        let slot = self.values.get_mut(idx).with_context(|| {
            format!(
                "Tensor has {} values but dimension '{}' is at index {}",
                len, dim_id, idx
            )
        })?;
        *slot = value;
        Ok(())
    }

    /// Clamps `value` into `[0.0, 1.0]`, returning `None` for NaN. `f32::clamp`
    /// would otherwise pass NaN through unchanged.
    fn clamp_unit(value: f32) -> Option<f32> {
        (!value.is_nan()).then(|| value.clamp(0.0, 1.0))
    }

    /// Parses `text` (surrounding whitespace ignored) as a dimension value clamped
    /// into `[0.0, 1.0]`.
    fn parse_unit_value(text: &str) -> Result<f32> {
        let value: f32 = text.trim().parse()?;
        Self::clamp_unit(value).context("NaN is not a valid dimension value")
    }
}
pub fn guided_capture(schema: &TensorSchema) -> Result<StateTensor> {
use std::io::{self, Write};
println!("\n{} ({})\n", schema.name, schema.id);
println!("Enter values 0.0-1.0 for each dimension.\n");
let mut values = Vec::with_capacity(schema.dimensions.len());
for dim in &schema.dimensions {
let rune = dim
.rune
.as_ref()
.map(|r| format!("{} ", r))
.unwrap_or_default();
println!("{}{}:", rune, dim.name);
println!(" Low (0.0): {}", dim.anchors.low);
if let Some(mid) = &dim.anchors.mid {
println!(" Mid (0.5): {}", mid);
}
println!(" High (1.0): {}", dim.anchors.high);
println!(" Default: {:.2}", dim.default);
print!("> ");
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
let input = input.trim();
let value: f32 = if input.is_empty() {
dim.default
} else {
input
.parse()
.with_context(|| format!("Invalid number: {}", input))?
};
values.push(value.clamp(0.0, 1.0));
println!();
}
Ok(StateTensor {
schema_id: schema.id.clone(),
values,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree within the 0.01 tolerance used throughout these tests.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 0.01);
    }

    /// Builds a dimension with default 0.5 from `(low, mid, high)` anchor text.
    fn dimension(
        id: &str,
        name: &str,
        rune: Option<&str>,
        (low, mid, high): (&str, Option<&str>, &str),
    ) -> TensorDimension {
        TensorDimension {
            id: id.into(),
            name: name.into(),
            rune: rune.map(String::from),
            anchors: DimensionAnchors {
                low: low.into(),
                mid: mid.map(String::from),
                high: high.into(),
            },
            default: 0.5,
        }
    }

    /// Builds a two-dimensional weighted mood with tolerance 0.3.
    fn mood(description: &str, tensor: [f32; 2], weights: [f32; 2]) -> TensorMood {
        TensorMood {
            description: description.into(),
            tensor: tensor.to_vec(),
            weights: Some(weights.to_vec()),
            tolerance: 0.3,
        }
    }

    /// Two-dimension schema with a "calm" and an "excited" mood.
    fn test_schema() -> TensorSchema {
        TensorSchema {
            id: "test".into(),
            name: "Test Schema".into(),
            version: 1,
            dimensions: vec![
                dimension("dim1", "Dimension 1", Some("A"), ("low1", Some("mid1"), "high1")),
                dimension("dim2", "Dimension 2", Some("B"), ("low2", None, "high2")),
            ],
            moods: [
                ("calm".to_string(), mood("Calm state", [0.2, 0.3], [1.0, 0.8])),
                ("excited".to_string(), mood("Excited state", [0.8, 0.9], [0.9, 1.0])),
            ]
            .into_iter()
            .collect(),
        }
    }

    #[test]
    fn test_parse_values() {
        let parsed = StateTensor::parse_values(&test_schema(), "0.3|0.7").unwrap();
        assert_eq!(parsed.schema_id, "test");
        assert_eq!(parsed.values.len(), 2);
        assert_close(parsed.values[0], 0.3);
        assert_close(parsed.values[1], 0.7);
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let original = StateTensor::new("crewu".into(), vec![0.3, 0.2, 0.7, 0.8, 0.4]);
        let wire = original.encode();
        assert!(wire.starts_with("@state:crewu|"));
        let restored = StateTensor::decode(&wire).unwrap();
        assert_eq!(restored.schema_id, "crewu");
        assert_eq!(restored.values.len(), 5);
        original
            .values
            .iter()
            .zip(&restored.values)
            .for_each(|(&a, &b)| assert_close(a, b));
    }

    #[test]
    fn test_nearest_mood() {
        let schema = test_schema();
        let probe = StateTensor::new("test".into(), vec![0.25, 0.35]);
        let (name, _, distance) = probe.nearest_mood(&schema).unwrap();
        assert_eq!(name, "calm");
        assert!(distance < 0.3);
    }

    #[test]
    fn test_distance_with_weights() {
        let schema = test_schema();
        let calm = &schema.moods["calm"];
        let near = StateTensor::new("test".into(), vec![0.2, 0.5]).distance_to_mood(calm);
        let far = StateTensor::new("test".into(), vec![0.4, 0.3]).distance_to_mood(calm);
        assert!(near < far);
    }

    #[test]
    fn test_parse_named_dimensions() {
        let schema = test_schema();
        let first = StateTensor::parse_named_dimensions(&schema, "dim1=0.3 dim2=0.7").unwrap();
        assert_eq!(first.schema_id, "test");
        assert_eq!(first.values.len(), 2);
        for (input, expected) in [
            ("dim1=0.3 dim2=0.7", [0.3, 0.7]),
            ("DIM1=0.4 DIM2=0.8", [0.4, 0.8]),
            ("d=0.5 dim2=0.6", [0.5, 0.6]),
        ] {
            let parsed = StateTensor::parse_named_dimensions(&schema, input).unwrap();
            assert_close(parsed.values[0], expected[0]);
            assert_close(parsed.values[1], expected[1]);
        }
    }

    #[test]
    fn test_parse_named_dimensions_missing() {
        let err = StateTensor::parse_named_dimensions(&test_schema(), "dim1=0.3").unwrap_err();
        assert!(
            err.to_string()
                .contains("No value provided for dimension 'dim2'")
        );
    }

    #[test]
    fn test_format_bootstrap() {
        let output = StateTensor::new("test".into(), vec![0.8, 0.2])
            .format_bootstrap(&test_schema())
            .unwrap();
        for needle in ["Wake State:", "@state:test", "A=dim1", "B=dim2", "high1", "low2"] {
            assert!(output.contains(needle), "missing {needle:?} in {output}");
        }
    }

    #[test]
    fn test_interpolate_anchor_description() {
        let tensor = StateTensor::new("test".into(), vec![0.0]);

        let with_mid = dimension("test", "Test", None, ("cold", Some("balanced"), "hot"));
        for (value, expected) in [(0.2, "cold"), (0.8, "hot"), (0.5, "balanced")] {
            assert_eq!(tensor.interpolate_anchor_description(&with_mid, value), expected);
        }

        let without_mid = dimension("test", "Test", None, ("cold", None, "hot"));
        assert_eq!(
            tensor.interpolate_anchor_description(&without_mid, 0.6),
            "moderately hot"
        );
        assert_eq!(
            tensor.interpolate_anchor_description(&without_mid, 0.4),
            "moderately cold"
        );
    }
}