use std::collections::HashMap;
use crate::{ResponsePath, ResponseValue};
/// Errors returned by the typed accessors on [`Responses`] (e.g. `get_string`, `get_int`).
#[derive(Debug, thiserror::Error)]
pub enum ResponseError {
    /// Nothing is stored under the requested path.
    #[error("Missing response for path: {0}")]
    MissingPath(ResponsePath),
    /// A value is stored under the path, but it is a different [`ResponseValue`]
    /// variant than the accessor asked for.
    #[error("Type mismatch at path '{path}': expected {expected}, got {actual}")]
    TypeMismatch {
        /// The path that was looked up.
        path: ResponsePath,
        /// Name of the variant the caller requested (e.g. `"String"`, `"Int"`).
        expected: &'static str,
        /// Name of the variant actually stored, as reported by `ResponseValue::type_name`.
        actual: &'static str,
    },
}
/// A collection of responses keyed by [`ResponsePath`].
///
/// Each path holds at most one [`ResponseValue`]. The backing store is a `HashMap`,
/// so iteration order is unspecified.
#[derive(Debug, Clone, Default)]
pub struct Responses {
    // Path -> value; inserting at an existing path replaces the previous value.
    values: HashMap<ResponsePath, ResponseValue>,
}
impl Responses {
    /// Creates an empty collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `path`, replacing any value previously stored there.
    pub fn insert(&mut self, path: impl Into<ResponsePath>, value: impl Into<ResponseValue>) {
        self.values.insert(path.into(), value.into());
    }

    /// Returns the value stored under `path`, if any.
    pub fn get(&self, path: &ResponsePath) -> Option<&ResponseValue> {
        self.values.get(path)
    }

    /// Returns `true` if a value of any variant is stored under `path`.
    ///
    /// Unlike [`Responses::has_value`], an empty string counts as present here.
    pub fn contains(&self, path: &ResponsePath) -> bool {
        self.values.contains_key(path)
    }

    /// Removes and returns the value stored under `path`, if any.
    pub fn remove(&mut self, path: &ResponsePath) -> Option<ResponseValue> {
        self.values.remove(path)
    }

    /// Iterates over all `(path, value)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&ResponsePath, &ResponseValue)> {
        self.values.iter()
    }

    /// Returns the number of stored responses.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` if no responses are stored.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Moves every entry of `other` into `self`.
    ///
    /// When both collections hold a value for the same path, the one from `other` wins.
    pub fn extend(&mut self, other: Responses) {
        self.values.extend(other.values);
    }

    /// Returns a new collection with only the entries under `prefix`, re-keyed relative
    /// to it.
    ///
    /// Matching and stripping are delegated to `ResponsePath::strip_path_prefix`. Entries
    /// for which it returns `None` are dropped. Kept values are cloned, and `self` is not
    /// modified.
    pub fn filter_prefix(&self, prefix: &ResponsePath) -> Self {
        let values = self
            .values
            .iter()
            .filter_map(|(path, value)| {
                path.strip_path_prefix(prefix)
                    .map(|stripped| (stripped, value.clone()))
            })
            .collect();
        Self { values }
    }

    /// Returns the string stored under `path`.
    ///
    /// # Errors
    ///
    /// [`ResponseError::MissingPath`] if nothing is stored under `path`, or
    /// [`ResponseError::TypeMismatch`] if the value is not `ResponseValue::String`.
    pub fn get_string(&self, path: &ResponsePath) -> Result<&str, ResponseError> {
        match self.require(path)? {
            ResponseValue::String(s) => Ok(s),
            other => Err(Self::mismatch(path, "String", other)),
        }
    }

    /// Returns the integer stored under `path`.
    ///
    /// # Errors
    ///
    /// [`ResponseError::MissingPath`] if nothing is stored under `path`, or
    /// [`ResponseError::TypeMismatch`] if the value is not `ResponseValue::Int`.
    pub fn get_int(&self, path: &ResponsePath) -> Result<i64, ResponseError> {
        match self.require(path)? {
            ResponseValue::Int(i) => Ok(*i),
            other => Err(Self::mismatch(path, "Int", other)),
        }
    }

    /// Returns the float stored under `path`.
    ///
    /// No conversion from `Int` is attempted.
    ///
    /// # Errors
    ///
    /// [`ResponseError::MissingPath`] if nothing is stored under `path`, or
    /// [`ResponseError::TypeMismatch`] if the value is not `ResponseValue::Float`.
    pub fn get_float(&self, path: &ResponsePath) -> Result<f64, ResponseError> {
        match self.require(path)? {
            ResponseValue::Float(f) => Ok(*f),
            other => Err(Self::mismatch(path, "Float", other)),
        }
    }

    /// Returns the boolean stored under `path`.
    ///
    /// # Errors
    ///
    /// [`ResponseError::MissingPath`] if nothing is stored under `path`, or
    /// [`ResponseError::TypeMismatch`] if the value is not `ResponseValue::Bool`.
    pub fn get_bool(&self, path: &ResponsePath) -> Result<bool, ResponseError> {
        match self.require(path)? {
            ResponseValue::Bool(b) => Ok(*b),
            other => Err(Self::mismatch(path, "Bool", other)),
        }
    }

    /// Returns the single chosen variant index stored under `path`.
    ///
    /// # Errors
    ///
    /// [`ResponseError::MissingPath`] if nothing is stored under `path`, or
    /// [`ResponseError::TypeMismatch`] if the value is not `ResponseValue::ChosenVariant`.
    pub fn get_chosen_variant(&self, path: &ResponsePath) -> Result<usize, ResponseError> {
        match self.require(path)? {
            ResponseValue::ChosenVariant(idx) => Ok(*idx),
            other => Err(Self::mismatch(path, "ChosenVariant", other)),
        }
    }

    /// Returns the chosen variant indices stored under `path`.
    ///
    /// # Errors
    ///
    /// [`ResponseError::MissingPath`] if nothing is stored under `path`, or
    /// [`ResponseError::TypeMismatch`] if the value is not `ResponseValue::ChosenVariants`.
    pub fn get_chosen_variants(&self, path: &ResponsePath) -> Result<&[usize], ResponseError> {
        match self.require(path)? {
            ResponseValue::ChosenVariants(indices) => Ok(indices),
            other => Err(Self::mismatch(path, "ChosenVariants", other)),
        }
    }

    /// Returns the list of strings stored under `path`.
    ///
    /// # Errors
    ///
    /// [`ResponseError::MissingPath`] if nothing is stored under `path`, or
    /// [`ResponseError::TypeMismatch`] if the value is not `ResponseValue::StringList`.
    pub fn get_string_list(&self, path: &ResponsePath) -> Result<&[String], ResponseError> {
        match self.require(path)? {
            ResponseValue::StringList(list) => Ok(list),
            other => Err(Self::mismatch(path, "StringList", other)),
        }
    }

    /// Returns the list of integers stored under `path`.
    ///
    /// # Errors
    ///
    /// [`ResponseError::MissingPath`] if nothing is stored under `path`, or
    /// [`ResponseError::TypeMismatch`] if the value is not `ResponseValue::IntList`.
    pub fn get_int_list(&self, path: &ResponsePath) -> Result<&[i64], ResponseError> {
        match self.require(path)? {
            ResponseValue::IntList(list) => Ok(list),
            other => Err(Self::mismatch(path, "IntList", other)),
        }
    }

    /// Returns the list of floats stored under `path`.
    ///
    /// # Errors
    ///
    /// [`ResponseError::MissingPath`] if nothing is stored under `path`, or
    /// [`ResponseError::TypeMismatch`] if the value is not `ResponseValue::FloatList`.
    pub fn get_float_list(&self, path: &ResponsePath) -> Result<&[f64], ResponseError> {
        match self.require(path)? {
            ResponseValue::FloatList(list) => Ok(list),
            other => Err(Self::mismatch(path, "FloatList", other)),
        }
    }

    /// Returns `true` if a value is stored under `path` and it is not an empty string.
    ///
    /// Every non-`String` variant counts as present, including empty lists.
    /// NOTE(review): confirm that empty lists should count as "answered".
    pub fn has_value(&self, path: &ResponsePath) -> bool {
        match self.get(path) {
            Some(ResponseValue::String(s)) => !s.is_empty(),
            Some(_) => true,
            None => false,
        }
    }

    /// Looks up `path`, turning absence into [`ResponseError::MissingPath`].
    fn require(&self, path: &ResponsePath) -> Result<&ResponseValue, ResponseError> {
        self.get(path)
            .ok_or_else(|| ResponseError::MissingPath(path.clone()))
    }

    /// Builds the [`ResponseError::TypeMismatch`] reported when the value at `path` is not
    /// the `expected` variant. Shared by all typed getters so their errors stay uniform.
    fn mismatch(
        path: &ResponsePath,
        expected: &'static str,
        actual: &ResponseValue,
    ) -> ResponseError {
        ResponseError::TypeMismatch {
            path: path.clone(),
            expected,
            actual: actual.type_name(),
        }
    }
}
/// Consumes the collection, yielding owned `(path, value)` pairs in unspecified order.
impl IntoIterator for Responses {
    type Item = (ResponsePath, ResponseValue);
    type IntoIter = std::collections::hash_map::IntoIter<ResponsePath, ResponseValue>;

    fn into_iter(self) -> Self::IntoIter {
        // Take ownership of the backing map and hand out its owning iterator.
        let Self { values } = self;
        values.into_iter()
    }
}
impl<'a> IntoIterator for &'a Responses {
type Item = (&'a ResponsePath, &'a ResponseValue);
type IntoIter = std::collections::hash_map::Iter<'a, ResponsePath, ResponseValue>;
fn into_iter(self) -> Self::IntoIter {
self.values.iter()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_and_get() {
        let mut responses = Responses::new();
        responses.insert("name", "Alice");
        responses.insert("age", ResponseValue::Int(30));
        assert_eq!(
            responses.get_string(&ResponsePath::new("name")).unwrap(),
            "Alice"
        );
        assert_eq!(responses.get_int(&ResponsePath::new("age")).unwrap(), 30);
    }

    #[test]
    fn filter_prefix() {
        let mut responses = Responses::new();
        responses.insert("address.street", "123 Main St");
        responses.insert("address.city", "Springfield");
        responses.insert("name", "Alice");
        let filtered = responses.filter_prefix(&ResponsePath::new("address"));
        assert_eq!(filtered.len(), 2);
        assert_eq!(
            filtered.get_string(&ResponsePath::new("street")).unwrap(),
            "123 Main St"
        );
        assert_eq!(
            filtered.get_string(&ResponsePath::new("city")).unwrap(),
            "Springfield"
        );
        // The source collection is left untouched.
        assert_eq!(responses.len(), 3);
    }

    #[test]
    fn type_mismatch_error() {
        let mut responses = Responses::new();
        responses.insert("age", ResponseValue::Int(30));
        let result = responses.get_string(&ResponsePath::new("age"));
        assert!(matches!(
            result,
            Err(ResponseError::TypeMismatch {
                expected: "String",
                ..
            })
        ));
    }

    #[test]
    fn missing_path_error() {
        let responses = Responses::new();
        let result = responses.get_int(&ResponsePath::new("missing"));
        assert!(matches!(result, Err(ResponseError::MissingPath(_))));
    }

    #[test]
    fn scalar_getters() {
        let mut responses = Responses::new();
        responses.insert("ratio", ResponseValue::Float(0.5));
        responses.insert("agreed", ResponseValue::Bool(true));
        responses.insert("choice", ResponseValue::ChosenVariant(2));
        assert_eq!(responses.get_float(&ResponsePath::new("ratio")).unwrap(), 0.5);
        assert!(responses.get_bool(&ResponsePath::new("agreed")).unwrap());
        assert_eq!(
            responses
                .get_chosen_variant(&ResponsePath::new("choice"))
                .unwrap(),
            2
        );
    }

    #[test]
    fn has_value_treats_empty_string_as_absent() {
        let mut responses = Responses::new();
        responses.insert("blank", "");
        responses.insert("name", "Alice");
        responses.insert("age", ResponseValue::Int(0));
        let blank = ResponsePath::new("blank");
        assert!(responses.contains(&blank));
        assert!(!responses.has_value(&blank));
        assert!(responses.has_value(&ResponsePath::new("name")));
        assert!(responses.has_value(&ResponsePath::new("age")));
        assert!(!responses.has_value(&ResponsePath::new("missing")));
    }

    #[test]
    fn remove_and_contains() {
        let mut responses = Responses::new();
        responses.insert("name", "Alice");
        let path = ResponsePath::new("name");
        assert!(responses.contains(&path));
        assert!(responses.remove(&path).is_some());
        assert!(!responses.contains(&path));
        assert!(responses.is_empty());
        assert!(responses.remove(&path).is_none());
    }

    #[test]
    fn extend_overwrites_existing_paths() {
        let mut base = Responses::new();
        base.insert("age", ResponseValue::Int(30));
        base.insert("name", "Alice");
        let mut update = Responses::new();
        update.insert("age", ResponseValue::Int(31));
        base.extend(update);
        assert_eq!(base.len(), 2);
        assert_eq!(base.get_int(&ResponsePath::new("age")).unwrap(), 31);
        assert_eq!(
            base.get_string(&ResponsePath::new("name")).unwrap(),
            "Alice"
        );
    }
}