use std::path::Path;
use serde::{Deserialize, Serialize};
use crate::error::ContractError;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BindingRegistry {
pub version: String,
pub target_crate: String,
#[serde(default)]
pub bindings: Vec<KernelBinding>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelBinding {
pub contract: String,
pub equation: String,
#[serde(default)]
pub module_path: Option<String>,
#[serde(default)]
pub function: Option<String>,
#[serde(default)]
pub signature: Option<String>,
pub status: ImplStatus,
#[serde(default)]
pub notes: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImplStatus {
Implemented,
Partial,
NotImplemented,
}
impl std::fmt::Display for ImplStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
Self::Implemented => "implemented",
Self::Partial => "partial",
Self::NotImplemented => "not_implemented",
};
write!(f, "{s}")
}
}
pub fn parse_binding(path: &Path) -> Result<BindingRegistry, ContractError> {
let content = std::fs::read_to_string(path)?;
parse_binding_str(&content)
}
pub fn parse_binding_str(yaml: &str) -> Result<BindingRegistry, ContractError> {
let registry: BindingRegistry = serde_yaml::from_str(yaml)?;
Ok(registry)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_minimal_binding() {
let yaml = r#"
version: "1.0.0"
target_crate: aprender
bindings: []
"#;
let reg = parse_binding_str(yaml).unwrap();
assert_eq!(reg.version, "1.0.0");
assert_eq!(reg.target_crate, "aprender");
assert!(reg.bindings.is_empty());
}
#[test]
fn parse_binding_with_entries() {
let yaml = r#"
version: "1.0.0"
target_crate: aprender
bindings:
- contract: softmax-kernel-v1.yaml
equation: softmax
module_path: "aprender::nn::functional::softmax"
function: softmax
signature: "fn softmax(x: &Tensor, dim: i32) -> Tensor"
status: implemented
- contract: activation-kernel-v1.yaml
equation: silu
status: not_implemented
notes: "Not yet available"
"#;
let reg = parse_binding_str(yaml).unwrap();
assert_eq!(reg.bindings.len(), 2);
assert_eq!(reg.bindings[0].equation, "softmax");
assert_eq!(reg.bindings[0].status, ImplStatus::Implemented);
assert!(reg.bindings[0].module_path.is_some());
assert_eq!(reg.bindings[1].equation, "silu");
assert_eq!(reg.bindings[1].status, ImplStatus::NotImplemented);
assert!(reg.bindings[1].module_path.is_none());
}
#[test]
fn parse_partial_status() {
let yaml = r#"
version: "1.0.0"
target_crate: test
bindings:
- contract: test.yaml
equation: f
module_path: "test::f"
function: f
status: partial
notes: "Only scalar path"
"#;
let reg = parse_binding_str(yaml).unwrap();
assert_eq!(reg.bindings[0].status, ImplStatus::Partial);
}
#[test]
fn impl_status_display() {
assert_eq!(ImplStatus::Implemented.to_string(), "implemented");
assert_eq!(ImplStatus::Partial.to_string(), "partial");
assert_eq!(ImplStatus::NotImplemented.to_string(), "not_implemented");
}
#[test]
fn parse_invalid_binding_yaml() {
let result = parse_binding_str("not: [valid: {{");
assert!(result.is_err());
}
#[test]
fn parse_binding_from_file() {
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("../../contracts/aprender/binding.yaml");
let reg = parse_binding(&path).unwrap();
assert_eq!(reg.target_crate, "aprender");
assert!(!reg.bindings.is_empty());
}
#[test]
fn parse_binding_nonexistent_file() {
let result = parse_binding(std::path::Path::new("/nonexistent/binding.yaml"));
assert!(result.is_err());
}
}