use crate::intern::Sym;
use crate::predict::FieldRange;
use crate::recursive::validate::{Score, Validate};
use crate::str_view::StrView;
use smallvec::SmallVec;
/// The kind of value a [`TypedField`] is expected to carry.
///
/// The explicit `u8` discriminants keep the enum one byte wide.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum ValueKind {
    /// Any text at all, including the empty string.
    Str = 0,
    /// Text that parses as an `i64` after trimming.
    Int = 1,
    /// Text that parses as an `f64` after trimming.
    Float = 2,
    /// One of `true/false/yes/no/1/0`, ASCII case-insensitive, after trimming.
    Bool = 3,
    /// Trimmed text wrapped in `{ ... }` (only the delimiters are checked).
    JsonObject = 4,
    /// Trimmed text wrapped in `[ ... ]` (only the delimiters are checked).
    JsonArray = 5,
    /// Any non-blank text.
    Enum = 6,
}

impl ValueKind {
    /// Human-readable name of the kind, as used in validation feedback.
    pub const fn label(self) -> &'static str {
        match self {
            ValueKind::Str => "string",
            ValueKind::Int => "integer",
            ValueKind::Float => "float",
            ValueKind::Bool => "boolean",
            ValueKind::JsonObject => "JSON object",
            ValueKind::JsonArray => "JSON array",
            ValueKind::Enum => "enum",
        }
    }

    /// Reports whether `text`, once surrounding whitespace is removed, looks like
    /// a value of this kind.
    ///
    /// The JSON kinds only inspect the first and last characters; they do not
    /// parse the content.
    pub fn matches(self, text: &str) -> bool {
        let value = text.trim();
        match self {
            ValueKind::Str => true,
            ValueKind::Int => value.parse::<i64>().is_ok(),
            ValueKind::Float => value.parse::<f64>().is_ok(),
            ValueKind::Bool => ["true", "false", "yes", "no", "1", "0"]
                .iter()
                .any(|word| value.eq_ignore_ascii_case(word)),
            ValueKind::JsonObject => Self::is_wrapped(value, '{', '}'),
            ValueKind::JsonArray => Self::is_wrapped(value, '[', ']'),
            ValueKind::Enum => !value.is_empty(),
        }
    }

    /// True when `value` starts with `open` and ends with `close`.
    fn is_wrapped(value: &str, open: char, close: char) -> bool {
        value.starts_with(open) && value.ends_with(close)
    }
}
/// Which side of a [`TypedSignature`] a field sits on.
///
/// Stored as a single byte thanks to `#[repr(u8)]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Direction {
    /// A field listed among the signature's inputs.
    Input = 0,
    /// A field listed among the signature's outputs; these are the fields
    /// located and type-checked in model output.
    Output = 1,
}
/// One named, typed slot of a [`TypedSignature`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TypedField {
    /// Interned field name; also used to derive the `Name:` label in text.
    pub name: Sym,
    /// Kind of value the field must hold.
    pub kind: ValueKind,
    /// Whether validation reports the field when it is absent or ill-typed.
    pub required: bool,
    /// Input or output side of the signature.
    pub direction: Direction,
}

impl TypedField {
    /// Builds a required field.
    #[inline]
    pub const fn new(name: Sym, kind: ValueKind, direction: Direction) -> Self {
        Self {
            name,
            kind,
            required: true,
            direction,
        }
    }

    /// Placeholder used to pre-fill unused slots of a signature's fixed arrays.
    const fn empty() -> Self {
        Self::new(Sym::EMPTY, ValueKind::Str, Direction::Input).optional()
    }

    /// Returns a copy of this field with `required` cleared.
    #[inline]
    pub const fn optional(self) -> Self {
        Self {
            required: false,
            ..self
        }
    }
}
/// Capacity of each of a [`TypedSignature`]'s input and output field lists.
const MAX_FIELDS: usize = 4;

/// A fixed-capacity, `Copy` description of a task's instruction and typed fields.
///
/// Built with a chain of `const fn` calls, so a signature can live in a `const`:
///
/// ```ignore
/// const SIG: TypedSignature = TypedSignature::new("Answer questions.")
///     .input(QUESTION, ValueKind::Str)
///     .output(ANSWER, ValueKind::Str);
/// ```
///
/// At most [`MAX_FIELDS`] inputs and [`MAX_FIELDS`] outputs are allowed.
#[derive(Debug, Clone, Copy)]
pub struct TypedSignature {
    instruction: &'static str,
    inputs: [TypedField; MAX_FIELDS],
    input_count: u8,
    outputs: [TypedField; MAX_FIELDS],
    output_count: u8,
}

impl TypedSignature {
    /// Starts a signature with the given instruction and no fields.
    #[inline]
    pub const fn new(instruction: &'static str) -> Self {
        const BLANK: TypedField = TypedField::empty();
        Self {
            instruction,
            inputs: [BLANK; MAX_FIELDS],
            input_count: 0,
            outputs: [BLANK; MAX_FIELDS],
            output_count: 0,
        }
    }

    /// Appends a required input field.
    ///
    /// # Panics
    /// If the signature already holds `MAX_FIELDS` inputs (a compile error when
    /// evaluated in a `const`).
    #[inline]
    pub const fn input(self, name: Sym, kind: ValueKind) -> Self {
        self.push_input(TypedField::new(name, kind, Direction::Input))
    }

    /// Appends an optional input field.
    ///
    /// # Panics
    /// If the signature already holds `MAX_FIELDS` inputs.
    #[inline]
    pub const fn input_optional(self, name: Sym, kind: ValueKind) -> Self {
        self.push_input(TypedField::new(name, kind, Direction::Input).optional())
    }

    /// Appends a required output field.
    ///
    /// # Panics
    /// If the signature already holds `MAX_FIELDS` outputs.
    #[inline]
    pub const fn output(self, name: Sym, kind: ValueKind) -> Self {
        self.push_output(TypedField::new(name, kind, Direction::Output))
    }

    /// Appends an optional output field.
    ///
    /// # Panics
    /// If the signature already holds `MAX_FIELDS` outputs.
    #[inline]
    pub const fn output_optional(self, name: Sym, kind: ValueKind) -> Self {
        self.push_output(TypedField::new(name, kind, Direction::Output).optional())
    }

    /// Stores `field` in the next free input slot.
    const fn push_input(mut self, field: TypedField) -> Self {
        let slot = self.input_count as usize;
        assert!(
            slot < MAX_FIELDS,
            "TypedSignature: exceeded maximum number of input fields"
        );
        self.inputs[slot] = field;
        self.input_count += 1;
        self
    }

    /// Stores `field` in the next free output slot.
    const fn push_output(mut self, field: TypedField) -> Self {
        let slot = self.output_count as usize;
        assert!(
            slot < MAX_FIELDS,
            "TypedSignature: exceeded maximum number of output fields"
        );
        self.outputs[slot] = field;
        self.output_count += 1;
        self
    }

    /// The instruction text the signature was created with.
    #[inline]
    pub const fn instruction(&self) -> &'static str {
        self.instruction
    }

    /// Number of input fields added so far.
    #[inline]
    pub const fn input_count(&self) -> usize {
        self.input_count as usize
    }

    /// Number of output fields added so far.
    #[inline]
    pub const fn output_count(&self) -> usize {
        self.output_count as usize
    }

    /// The populated input fields, in insertion order.
    #[inline]
    pub fn inputs(&self) -> &[TypedField] {
        &self.inputs[..self.input_count()]
    }

    /// The populated output fields, in insertion order.
    #[inline]
    pub fn outputs(&self) -> &[TypedField] {
        &self.outputs[..self.output_count()]
    }

    /// First input field named `name`, if any.
    pub fn get_input(&self, name: Sym) -> Option<&TypedField> {
        Self::lookup(self.inputs(), name)
    }

    /// First output field named `name`, if any.
    pub fn get_output(&self, name: Sym) -> Option<&TypedField> {
        Self::lookup(self.outputs(), name)
    }

    /// Linear search for the first field with a matching name.
    fn lookup(fields: &[TypedField], name: Sym) -> Option<&TypedField> {
        fields.iter().find(|field| field.name == name)
    }

    /// Builds a [`TypedFieldValidator`] over a copy of this signature.
    pub fn validator(&self) -> TypedFieldValidator {
        TypedFieldValidator::new(*self)
    }
}
/// Model output text together with the byte ranges of the named fields found in it.
///
/// Values are never copied. Each entry pairs a field symbol with a [`FieldRange`]
/// into `raw`, and the getters slice (and, for the typed ones, parse) on demand.
#[derive(Debug, Clone)]
pub struct ParsedOutput<'a> {
    raw: StrView<'a>,
    fields: SmallVec<[(Sym, FieldRange); 4]>,
}

impl<'a> ParsedOutput<'a> {
    /// Creates a result over `raw` with no fields recorded yet.
    pub fn new(raw: StrView<'a>) -> Self {
        Self {
            raw,
            fields: SmallVec::new(),
        }
    }

    /// Creates a result over `raw` from already-computed `(field, range)` pairs.
    pub fn with_fields(raw: StrView<'a>, fields: SmallVec<[(Sym, FieldRange); 4]>) -> Self {
        Self { raw, fields }
    }

    /// Records `range` as the location of field `name` in the raw text.
    ///
    /// The range is not validated here; it is only resolved by [`Self::get_raw`].
    pub fn push(&mut self, name: Sym, range: FieldRange) {
        self.fields.push((name, range));
    }

    /// Returns the full underlying text.
    #[inline]
    pub fn raw(&self) -> StrView<'a> {
        self.raw
    }

    /// Iterates over the recorded `(field, range)` pairs in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = &(Sym, FieldRange)> {
        self.fields.iter()
    }

    /// Number of recorded entries (a field pushed twice counts twice).
    #[inline]
    pub fn field_count(&self) -> usize {
        self.fields.len()
    }

    /// Returns the untrimmed text of the first entry recorded for `name`.
    ///
    /// Returns `None` when `name` was never recorded, or when `StrView::try_slice`
    /// cannot slice that entry's range out of the raw text.
    pub fn get_raw(&self, name: Sym) -> Option<&'a str> {
        let (_, range) = self.fields.iter().find(|(sym, _)| *sym == name)?;
        self.raw.try_slice(range.as_range()).map(|v| v.as_str())
    }

    /// Like [`Self::get_raw`], with surrounding whitespace trimmed.
    pub fn get_str(&self, name: Sym) -> Option<&'a str> {
        self.get_raw(name).map(|s| s.trim())
    }

    /// Parses the trimmed value as an `i64`; `None` if missing or unparsable.
    pub fn get_int(&self, name: Sym) -> Option<i64> {
        self.get_str(name).and_then(|s| s.parse().ok())
    }

    /// Parses the trimmed value as an `f64`; `None` if missing or unparsable.
    pub fn get_float(&self, name: Sym) -> Option<f64> {
        self.get_str(name).and_then(|s| s.parse().ok())
    }

    /// Interprets the trimmed value as a boolean, ASCII case-insensitively.
    ///
    /// `true`/`yes`/`1` map to `Some(true)` and `false`/`no`/`0` map to
    /// `Some(false)`. Anything else, or a missing field, yields `None`.
    pub fn get_bool(&self, name: Sym) -> Option<bool> {
        const TRUTHY: [&str; 3] = ["true", "yes", "1"];
        const FALSY: [&str; 3] = ["false", "no", "0"];
        let value = self.get_str(name)?;
        // Compare in place instead of allocating a lowercased copy.
        if TRUTHY.iter().any(|w| value.eq_ignore_ascii_case(w)) {
            Some(true)
        } else if FALSY.iter().any(|w| value.eq_ignore_ascii_case(w)) {
            Some(false)
        } else {
            None
        }
    }

    /// Locates each output field of `sig` in `raw` via its `Label:` marker.
    ///
    /// The label for an output named `score` is `Score:` (first character
    /// ASCII-uppercased, case-sensitive match). The recorded value starts after
    /// the label, skipping spaces. It runs to the end of that line, and a
    /// trailing `\r` is excluded.
    ///
    /// Label choice:
    /// - An occurrence at the start of a line (optionally indented by spaces or
    ///   tabs) is preferred, so a label mentioned inside another field's value
    ///   does not hijack the field.
    /// - Otherwise the first occurrence not glued to a preceding word character
    ///   is used. `Final Answer: x` still matches `Answer:`, while
    ///   `FinalScore: x` is not taken for `Score:`.
    ///
    /// Fields whose label is absent are not recorded. Fields whose offsets do
    /// not fit in `u32` are also skipped.
    pub fn parse_from_sig(raw: StrView<'a>, sig: &TypedSignature) -> Self {
        let text = raw.as_str();
        let mut out = Self::new(raw);
        for field in sig.outputs() {
            let label = output_label(field.name.as_str());
            let label_at = match locate_label(text, &label) {
                Some(idx) => idx,
                None => continue,
            };
            let after = label_at + label.len();
            let start = after + text[after..].bytes().take_while(|&b| b == b' ').count();
            let mut end = text[start..].find('\n').map_or(text.len(), |i| start + i);
            // Keep CRLF line endings out of the raw value.
            if end > start && text.as_bytes()[end - 1] == b'\r' {
                end -= 1;
            }
            // FieldRange stores u32 offsets: skip instead of silently truncating.
            // `start <= end`, so checking `end` covers both.
            if end <= u32::MAX as usize {
                out.push(field.name, FieldRange::new(start as u32, end as u32));
            }
        }
        out
    }
}

/// Builds the `Label:` marker for a field name by ASCII-uppercasing its first character.
fn output_label(name: &str) -> String {
    let mut label = String::with_capacity(name.len() + 1);
    let mut chars = name.chars();
    if let Some(first) = chars.next() {
        label.push(first.to_ascii_uppercase());
        label.push_str(chars.as_str());
    }
    label.push(':');
    label
}

/// Byte offset of the preferred occurrence of `label` in `text`.
///
/// Returns the first occurrence that begins a line (after optional spaces or
/// tabs). Failing that, it returns the first occurrence whose preceding
/// character is not alphanumeric or `_`. Returns `None` if neither exists.
fn locate_label(text: &str, label: &str) -> Option<usize> {
    let mut fallback = None;
    for (idx, _) in text.match_indices(label) {
        let before = &text[..idx];
        let at_line_start = before
            .bytes()
            .rev()
            .take_while(|&b| b != b'\n')
            .all(|b| b == b' ' || b == b'\t');
        if at_line_start {
            return Some(idx);
        }
        let glued_to_word = matches!(
            before.chars().next_back(),
            Some(c) if c.is_alphanumeric() || c == '_'
        );
        if fallback.is_none() && !glued_to_word {
            fallback = Some(idx);
        }
    }
    fallback
}
/// A [`Validate`] implementation that checks model output against the output
/// fields of a [`TypedSignature`].
///
/// Holds its own copy of the signature, so it is `Copy` and self-contained.
#[derive(Debug, Clone, Copy)]
pub struct TypedFieldValidator {
    sig: TypedSignature,
}

impl TypedFieldValidator {
    /// Wraps `sig` in a validator.
    pub const fn new(sig: TypedSignature) -> Self {
        TypedFieldValidator { sig }
    }

    /// Borrows the signature this validator checks against.
    pub const fn signature(&self) -> &TypedSignature {
        &self.sig
    }
}
impl Validate for TypedFieldValidator {
    /// Scores `text` by the fraction of required output fields that are present
    /// and whose value matches the declared [`ValueKind`].
    ///
    /// Fields are located with `ParsedOutput::parse_from_sig`. Optional fields
    /// are ignored entirely. Returns `Score::pass()` when nothing is wrong,
    /// including when there are no outputs or no required outputs. Otherwise it
    /// returns the pass ratio together with one feedback entry per failing
    /// field, joined by `"; "` in signature order.
    fn validate(&self, text: &str) -> Score<'static> {
        let outputs = self.sig.outputs();
        if outputs.is_empty() {
            return Score::pass();
        }
        let parsed = ParsedOutput::parse_from_sig(StrView::new(text), &self.sig);

        let mut required = 0usize;
        let mut satisfied = 0usize;
        let mut problems: SmallVec<[String; 4]> = SmallVec::new();
        for field in outputs.iter().filter(|f| f.required) {
            required += 1;
            match parsed.get_raw(field.name) {
                Some(raw) if field.kind.matches(raw) => satisfied += 1,
                Some(raw) => problems.push(format!(
                    "Field '{}': expected {}, got {:?}",
                    field.name.as_str(),
                    field.kind.label(),
                    raw.trim(),
                )),
                None => problems.push(format!(
                    "Missing required field '{}'",
                    field.name.as_str()
                )),
            }
        }

        // No complaints, which also covers "no required fields", is a perfect score.
        // Otherwise at least one field was required, so the division is safe.
        if problems.is_empty() {
            return Score::pass();
        }
        Score::with_feedback(satisfied as f64 / required as f64, problems.join("; "))
    }

    /// Stable identifier of this validator.
    fn name(&self) -> &'static str {
        "typed_field_validator"
    }
}
#[derive(Debug, Clone)]
pub struct TypedDemo<'a> {
pub inputs: SmallVec<[(Sym, &'a str); 4]>,
pub outputs: SmallVec<[(Sym, &'a str); 4]>,
}
impl<'a> TypedDemo<'a> {
pub fn new() -> Self {
Self {
inputs: SmallVec::new(),
outputs: SmallVec::new(),
}
}
pub fn input(mut self, name: Sym, value: &'a str) -> Self {
self.inputs.push((name, value));
self
}
pub fn output(mut self, name: Sym, value: &'a str) -> Self {
self.outputs.push((name, value));
self
}
pub fn format(&self) -> String {
let mut buf = String::new();
for (sym, val) in &self.inputs {
let name = sym.as_str();
let mut chars = name.chars();
if let Some(first) = chars.next() {
buf.push(first.to_ascii_uppercase());
buf.extend(chars);
}
buf.push_str(": ");
buf.push_str(val);
buf.push('\n');
}
for (sym, val) in &self.outputs {
let name = sym.as_str();
let mut chars = name.chars();
if let Some(first) = chars.next() {
buf.push(first.to_ascii_uppercase());
buf.extend(chars);
}
buf.push_str(": ");
buf.push_str(val);
buf.push('\n');
}
buf
}
}
impl<'a> Default for TypedDemo<'a> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intern::{sym, ANSWER, QUESTION, SCORE as SYM_SCORE};

    /// Question -> (answer, score) signature shared by the parsing and validation tests.
    const SCORE_ANSWER_SIG: TypedSignature = TypedSignature::new("Score an answer.")
        .input(QUESTION, ValueKind::Str)
        .output(ANSWER, ValueKind::Str)
        .output(SYM_SCORE, ValueKind::Float);

    #[test]
    fn typed_field_is_8_bytes() {
        assert_eq!(std::mem::size_of::<TypedField>(), 8);
    }

    #[test]
    fn value_kind_is_1_byte() {
        assert_eq!(std::mem::size_of::<ValueKind>(), 1);
    }

    #[test]
    fn direction_is_1_byte() {
        assert_eq!(std::mem::size_of::<Direction>(), 1);
    }

    #[test]
    fn typed_signature_is_copy() {
        const SIG: TypedSignature = TypedSignature::new("test");
        let a = SIG;
        let b = a; // a copy: `a` remains usable below
        assert_eq!(a.instruction(), b.instruction());
    }

    #[test]
    fn const_construction_basic() {
        const SIG: TypedSignature = TypedSignature::new("Answer questions.")
            .input(QUESTION, ValueKind::Str)
            .output(ANSWER, ValueKind::Str);
        assert_eq!(SIG.instruction(), "Answer questions.");
        assert_eq!(SIG.input_count(), 1);
        assert_eq!(SIG.output_count(), 1);

        let inp = SIG.inputs();
        assert_eq!(inp[0].name, QUESTION);
        assert_eq!(inp[0].kind, ValueKind::Str);
        assert!(inp[0].required);
        assert_eq!(inp[0].direction, Direction::Input);

        let out = SIG.outputs();
        assert_eq!(out[0].name, ANSWER);
        assert_eq!(out[0].kind, ValueKind::Str);
        assert!(out[0].required);
        assert_eq!(out[0].direction, Direction::Output);
    }

    #[test]
    fn const_construction_multi_field() {
        assert_eq!(SCORE_ANSWER_SIG.input_count(), 1);
        assert_eq!(SCORE_ANSWER_SIG.output_count(), 2);
        assert_eq!(SCORE_ANSWER_SIG.outputs()[1].kind, ValueKind::Float);
    }

    #[test]
    fn const_construction_optional_fields() {
        const SIG: TypedSignature = TypedSignature::new("Optional test")
            .input(QUESTION, ValueKind::Str)
            .output(ANSWER, ValueKind::Str)
            .output_optional(SYM_SCORE, ValueKind::Float);
        assert!(SIG.outputs()[0].required);
        assert!(!SIG.outputs()[1].required);
    }

    #[test]
    fn input_optional_marks_field_not_required() {
        const SIG: TypedSignature = TypedSignature::new("Optional input")
            .input(QUESTION, ValueKind::Str)
            .input_optional(ANSWER, ValueKind::Str);
        assert!(SIG.inputs()[0].required);
        assert!(!SIG.inputs()[1].required);
        assert_eq!(SIG.inputs()[1].direction, Direction::Input);
    }

    #[test]
    fn typed_field_optional_clears_only_required() {
        let field = TypedField::new(ANSWER, ValueKind::Int, Direction::Output);
        assert!(field.required);
        let field = field.optional();
        assert!(!field.required);
        assert_eq!(field.name, ANSWER);
        assert_eq!(field.kind, ValueKind::Int);
        assert_eq!(field.direction, Direction::Output);
    }

    #[test]
    fn const_construction_max_fields() {
        const SIG: TypedSignature = TypedSignature::new("Full")
            .input(QUESTION, ValueKind::Str)
            .input(ANSWER, ValueKind::Str)
            .input(QUESTION, ValueKind::Str)
            .input(ANSWER, ValueKind::Str)
            .output(ANSWER, ValueKind::Str)
            .output(QUESTION, ValueKind::Int)
            .output(ANSWER, ValueKind::Float)
            .output(QUESTION, ValueKind::Bool);
        assert_eq!(SIG.input_count(), 4);
        assert_eq!(SIG.output_count(), 4);

        // The same builder also works at runtime with freshly interned symbols.
        let context = sym("context");
        let reasoning = sym("reasoning");
        let evidence = sym("evidence");
        let rt_sig = TypedSignature::new("Runtime")
            .input(context, ValueKind::Str)
            .output(reasoning, ValueKind::Str)
            .output(evidence, ValueKind::Str);
        assert_eq!(rt_sig.input_count(), 1);
        assert_eq!(rt_sig.output_count(), 2);
    }

    #[test]
    #[should_panic(expected = "exceeded maximum number of input fields")]
    fn runtime_construction_rejects_too_many_inputs() {
        let _ = TypedSignature::new("Too many")
            .input(QUESTION, ValueKind::Str)
            .input(ANSWER, ValueKind::Str)
            .input(QUESTION, ValueKind::Str)
            .input(ANSWER, ValueKind::Str)
            .input(QUESTION, ValueKind::Str);
    }

    #[test]
    fn value_kind_str_matches_anything() {
        assert!(ValueKind::Str.matches("hello"));
        assert!(ValueKind::Str.matches(""));
        assert!(ValueKind::Str.matches("42"));
    }

    #[test]
    fn value_kind_int_matches() {
        assert!(ValueKind::Int.matches("42"));
        assert!(ValueKind::Int.matches("-7"));
        assert!(ValueKind::Int.matches(" 100 "));
        assert!(!ValueKind::Int.matches("3.14"));
        assert!(!ValueKind::Int.matches("abc"));
    }

    #[test]
    fn value_kind_float_matches() {
        assert!(ValueKind::Float.matches("3.14"));
        assert!(ValueKind::Float.matches("-0.5"));
        assert!(ValueKind::Float.matches("42"));
        assert!(ValueKind::Float.matches(" 1e10 "));
        assert!(!ValueKind::Float.matches("abc"));
    }

    #[test]
    fn value_kind_bool_matches() {
        assert!(ValueKind::Bool.matches("true"));
        assert!(ValueKind::Bool.matches("False"));
        assert!(ValueKind::Bool.matches("YES"));
        assert!(ValueKind::Bool.matches("no"));
        assert!(ValueKind::Bool.matches("1"));
        assert!(ValueKind::Bool.matches("0"));
        assert!(!ValueKind::Bool.matches("maybe"));
        assert!(!ValueKind::Bool.matches(""));
    }

    #[test]
    fn value_kind_json_object_matches() {
        assert!(ValueKind::JsonObject.matches(r#"{"key": "value"}"#));
        assert!(ValueKind::JsonObject.matches("{}"));
        assert!(!ValueKind::JsonObject.matches("[1,2]"));
        assert!(!ValueKind::JsonObject.matches("hello"));
    }

    #[test]
    fn value_kind_json_array_matches() {
        assert!(ValueKind::JsonArray.matches("[1, 2, 3]"));
        assert!(ValueKind::JsonArray.matches("[]"));
        assert!(!ValueKind::JsonArray.matches("{}"));
        assert!(!ValueKind::JsonArray.matches("hello"));
    }

    #[test]
    fn value_kind_enum_matches() {
        assert!(ValueKind::Enum.matches("Option1"));
        assert!(!ValueKind::Enum.matches(""));
        assert!(!ValueKind::Enum.matches(" "));
    }

    #[test]
    fn value_kind_labels() {
        assert_eq!(ValueKind::Str.label(), "string");
        assert_eq!(ValueKind::Int.label(), "integer");
        assert_eq!(ValueKind::Float.label(), "float");
        assert_eq!(ValueKind::Bool.label(), "boolean");
        assert_eq!(ValueKind::JsonObject.label(), "JSON object");
        assert_eq!(ValueKind::JsonArray.label(), "JSON array");
        assert_eq!(ValueKind::Enum.label(), "enum");
    }

    #[test]
    fn parsed_output_basic() {
        let text = "Answer: Paris\nScore: 0.95\n";
        let view = StrView::new(text);
        let mut parsed = ParsedOutput::new(view);
        parsed.push(ANSWER, FieldRange::new(8, 13));
        parsed.push(SYM_SCORE, FieldRange::new(21, 25));
        assert_eq!(parsed.get_str(ANSWER), Some("Paris"));
        assert_eq!(parsed.get_str(SYM_SCORE), Some("0.95"));
        assert_eq!(parsed.get_float(SYM_SCORE), Some(0.95));
        assert_eq!(parsed.get_str(QUESTION), None);
        assert_eq!(parsed.field_count(), 2);
    }

    #[test]
    fn parsed_output_get_int() {
        let text = "Count: 42";
        let view = StrView::new(text);
        let count_sym = sym("count");
        let mut parsed = ParsedOutput::new(view);
        parsed.push(count_sym, FieldRange::new(7, 9));
        assert_eq!(parsed.get_int(count_sym), Some(42));
    }

    #[test]
    fn parsed_output_get_bool() {
        let text = "Valid: true";
        let view = StrView::new(text);
        let valid_sym = sym("valid");
        let mut parsed = ParsedOutput::new(view);
        parsed.push(valid_sym, FieldRange::new(7, 11));
        assert_eq!(parsed.get_bool(valid_sym), Some(true));
    }

    #[test]
    fn parsed_output_get_bool_false_and_unknown() {
        let text = "Answer: No\nScore: maybe";
        let view = StrView::new(text);
        let mut parsed = ParsedOutput::new(view);
        parsed.push(ANSWER, FieldRange::new(8, 10));
        parsed.push(SYM_SCORE, FieldRange::new(18, 23));
        assert_eq!(parsed.get_bool(ANSWER), Some(false));
        assert_eq!(parsed.get_bool(SYM_SCORE), None);
    }

    #[test]
    fn parsed_output_parse_from_sig() {
        let text = "Answer: The capital of France is Paris\nScore: 0.95\n";
        let view = StrView::new(text);
        let parsed = ParsedOutput::parse_from_sig(view, &SCORE_ANSWER_SIG);
        assert_eq!(parsed.field_count(), 2);
        assert_eq!(
            parsed.get_str(ANSWER),
            Some("The capital of France is Paris")
        );
        assert_eq!(parsed.get_float(SYM_SCORE), Some(0.95));
    }

    #[test]
    fn parsed_output_parse_from_sig_missing_field() {
        let text = "Answer: Paris\n";
        let view = StrView::new(text);
        let parsed = ParsedOutput::parse_from_sig(view, &SCORE_ANSWER_SIG);
        assert_eq!(parsed.field_count(), 1);
        assert_eq!(parsed.get_str(ANSWER), Some("Paris"));
        assert_eq!(parsed.get_float(SYM_SCORE), None);
    }

    #[test]
    fn typed_demo_format_round_trips_through_parse_from_sig() {
        let demo = TypedDemo::new()
            .input(QUESTION, "What is the capital of France?")
            .output(ANSWER, "Paris")
            .output(SYM_SCORE, "0.95");
        let formatted = demo.format();
        let parsed =
            ParsedOutput::parse_from_sig(StrView::new(formatted.as_str()), &SCORE_ANSWER_SIG);
        assert_eq!(parsed.get_str(ANSWER), Some("Paris"));
        assert_eq!(parsed.get_float(SYM_SCORE), Some(0.95));
    }

    #[test]
    fn parsed_output_raw_accessor() {
        let text = "hello world";
        let view = StrView::new(text);
        let parsed = ParsedOutput::new(view);
        assert_eq!(parsed.raw().as_str(), "hello world");
    }

    #[test]
    fn parsed_output_with_fields() {
        let text = "Answer: yes";
        let view = StrView::new(text);
        let fields: SmallVec<[(Sym, FieldRange); 4]> =
            smallvec::smallvec![(ANSWER, FieldRange::new(8, 11))];
        let parsed = ParsedOutput::with_fields(view, fields);
        assert_eq!(parsed.get_str(ANSWER), Some("yes"));
    }

    #[test]
    fn parsed_output_whitespace_trimming() {
        // Three spaces of padding on each side of the value, then a newline.
        let text = "Answer:   Paris   \n";
        // Derive the span from the text instead of hard-coding offsets.
        let start = "Answer:".len();
        let end = text.len() - 1; // stop before the newline
        let view = StrView::new(text);
        let mut parsed = ParsedOutput::new(view);
        parsed.push(ANSWER, FieldRange::new(start as u32, end as u32));
        assert_eq!(parsed.get_raw(ANSWER), Some("   Paris   "));
        assert_eq!(parsed.get_str(ANSWER), Some("Paris"));
    }

    #[test]
    fn parsed_output_iter() {
        let text = "Answer: a\nScore: 1";
        let view = StrView::new(text);
        let mut parsed = ParsedOutput::new(view);
        parsed.push(ANSWER, FieldRange::new(8, 9));
        parsed.push(SYM_SCORE, FieldRange::new(17, 18));
        let pairs: Vec<_> = parsed.iter().collect();
        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].0, ANSWER);
        assert_eq!(pairs[1].0, SYM_SCORE);
    }

    #[test]
    fn validator_all_fields_present_and_correct() {
        let v = SCORE_ANSWER_SIG.validator();
        let score = v.validate("Answer: Paris\nScore: 0.95");
        assert!(score.is_perfect(), "score = {:?}", score);
    }

    #[test]
    fn validator_missing_required_field() {
        let v = SCORE_ANSWER_SIG.validator();
        let score = v.validate("Answer: Paris");
        assert!((score.value - 0.5).abs() < f64::EPSILON);
        assert!(score
            .feedback_str()
            .unwrap()
            .contains("Missing required field"));
    }

    #[test]
    fn validator_wrong_type() {
        const SIG: TypedSignature =
            TypedSignature::new("Count things.").output(SYM_SCORE, ValueKind::Int);
        let v = SIG.validator();
        let score = v.validate("Score: not_a_number");
        assert!(score.value.abs() < f64::EPSILON);
        assert!(score.feedback_str().unwrap().contains("expected integer"));
    }

    #[test]
    fn validator_optional_field_not_required() {
        const SIG: TypedSignature = TypedSignature::new("Optional test")
            .output(ANSWER, ValueKind::Str)
            .output_optional(SYM_SCORE, ValueKind::Float);
        let v = SIG.validator();
        let score = v.validate("Answer: Paris");
        assert!(score.is_perfect());
    }

    #[test]
    fn validator_no_outputs() {
        const SIG: TypedSignature =
            TypedSignature::new("Input only").input(QUESTION, ValueKind::Str);
        let v = SIG.validator();
        let score = v.validate("anything");
        assert!(score.is_perfect());
    }

    #[test]
    fn validator_name() {
        const SIG: TypedSignature = TypedSignature::new("test");
        let v = SIG.validator();
        assert_eq!(v.name(), "typed_field_validator");
    }

    #[test]
    fn validator_via_constructor() {
        const SIG: TypedSignature = TypedSignature::new("test").output(ANSWER, ValueKind::Str);
        let v = TypedFieldValidator::new(SIG);
        assert_eq!(v.signature().instruction(), "test");
        assert!(v.validate("Answer: hello").is_perfect());
    }

    #[test]
    fn typed_demo_basic() {
        let demo = TypedDemo::new()
            .input(QUESTION, "What is the capital of France?")
            .output(ANSWER, "Paris");
        assert_eq!(demo.inputs.len(), 1);
        assert_eq!(demo.outputs.len(), 1);
        let formatted = demo.format();
        assert!(formatted.contains("Question: What is the capital of France?"));
        assert!(formatted.contains("Answer: Paris"));
    }

    #[test]
    fn typed_demo_default() {
        let demo = TypedDemo::default();
        assert!(demo.inputs.is_empty());
        assert!(demo.outputs.is_empty());
    }

    #[test]
    fn typed_demo_multi_field() {
        let demo = TypedDemo::new()
            .input(QUESTION, "2+2?")
            .output(ANSWER, "4")
            .output(SYM_SCORE, "1.0");
        let formatted = demo.format();
        assert!(formatted.contains("Question: 2+2?"));
        assert!(formatted.contains("Answer: 4"));
        assert!(formatted.contains("Score: 1.0"));
    }

    #[test]
    fn get_input_output_lookup() {
        let sig = SCORE_ANSWER_SIG;
        assert!(sig.get_input(QUESTION).is_some());
        assert!(sig.get_input(ANSWER).is_none());
        assert!(sig.get_output(ANSWER).is_some());
        assert!(sig.get_output(SYM_SCORE).is_some());
        assert!(sig.get_output(QUESTION).is_none());
    }

    #[test]
    fn field_range_as_range() {
        let fr = FieldRange::new(5, 10);
        assert_eq!(fr.as_range(), 5..10);
        assert_eq!(fr.len(), 5);
        assert!(!fr.is_empty());
        let empty_fr = FieldRange::new(3, 3);
        assert!(empty_fr.is_empty());
    }
}