provable_contracts/
binding.rs1use std::path::Path;
13
14use serde::{Deserialize, Serialize};
15
16use crate::error::ContractError;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct BindingRegistry {
21 pub version: String,
22 pub target_crate: String,
23 #[serde(default)]
26 pub critical_path: Vec<String>,
27 #[serde(default)]
28 pub bindings: Vec<KernelBinding>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct KernelBinding {
34 pub contract: String,
36 pub equation: String,
38 #[serde(default)]
40 pub module_path: Option<String>,
41 #[serde(default)]
43 pub function: Option<String>,
44 #[serde(default)]
46 pub signature: Option<String>,
47 pub status: ImplStatus,
49 #[serde(default)]
51 pub notes: Option<String>,
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum ImplStatus {
58 Implemented,
60 Partial,
62 NotImplemented,
64 Pending,
66}
67
68impl std::fmt::Display for ImplStatus {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 let s = match self {
72 Self::Implemented => "implemented",
73 Self::Partial => "partial",
74 Self::NotImplemented => "not_implemented",
75 Self::Pending => "pending",
76 };
77 write!(f, "{s}")
78 }
79}
80
81pub fn parse_binding(path: &Path) -> Result<BindingRegistry, ContractError> {
88 let content = std::fs::read_to_string(path)?;
89 parse_binding_str(&content)
90}
91
92pub fn parse_binding_str(yaml: &str) -> Result<BindingRegistry, ContractError> {
94 let registry: BindingRegistry = serde_yaml::from_str(yaml)?;
95 Ok(registry)
96}
97
/// Removes a trailing `.yaml` or `.yml` extension from a contract identifier.
///
/// This makes `"softmax-kernel-v1.yaml"`, `"softmax-kernel-v1.yml"` and
/// `"softmax-kernel-v1"` compare equal. At most one suffix is removed, and
/// `.yaml` is tried first. Identifiers with neither suffix come back unchanged.
pub fn normalize_contract_id(id: &str) -> &str {
    const EXTENSIONS: [&str; 2] = [".yaml", ".yml"];
    EXTENSIONS
        .iter()
        .find_map(|ext| id.strip_suffix(*ext))
        .unwrap_or(id)
}
108
109impl BindingRegistry {
110 pub fn bindings_for(&self, contract_id: &str) -> Vec<&KernelBinding> {
112 let needle = normalize_contract_id(contract_id);
113 self.bindings
114 .iter()
115 .filter(|b| normalize_contract_id(&b.contract) == needle)
116 .collect()
117 }
118
119 pub fn find_binding(&self, contract_id: &str, equation: &str) -> Option<&KernelBinding> {
121 let needle = normalize_contract_id(contract_id);
122 self.bindings
123 .iter()
124 .find(|b| normalize_contract_id(&b.contract) == needle && b.equation == equation)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry used by the lookup tests: mixes `.yaml`, `.yml` and bare ids.
    fn lookup_registry() -> BindingRegistry {
        let yaml = r#"
version: "1.0.0"
target_crate: test
bindings:
  - contract: softmax-kernel-v1.yaml
    equation: softmax
    status: implemented
  - contract: softmax-kernel-v1.yml
    equation: log_softmax
    status: partial
  - contract: activation-kernel-v1
    equation: silu
    status: pending
"#;
        parse_binding_str(yaml).unwrap()
    }

    #[test]
    fn parse_minimal_binding() {
        let yaml = r#"
version: "1.0.0"
target_crate: aprender
bindings: []
"#;
        let reg = parse_binding_str(yaml).unwrap();
        assert_eq!(reg.version, "1.0.0");
        assert_eq!(reg.target_crate, "aprender");
        assert!(reg.bindings.is_empty());
        assert!(reg.critical_path.is_empty());
    }

    #[test]
    fn parse_omitted_lists_default_to_empty() {
        let yaml = r#"
version: "1.0.0"
target_crate: aprender
"#;
        let reg = parse_binding_str(yaml).unwrap();
        assert!(reg.bindings.is_empty());
        assert!(reg.critical_path.is_empty());
    }

    #[test]
    fn parse_critical_path() {
        let yaml = r#"
version: "1.0.0"
target_crate: aprender
critical_path:
  - softmax-kernel-v1.yaml
  - activation-kernel-v1
bindings: []
"#;
        let reg = parse_binding_str(yaml).unwrap();
        assert_eq!(
            reg.critical_path,
            vec!["softmax-kernel-v1.yaml", "activation-kernel-v1"]
        );
    }

    #[test]
    fn parse_binding_with_entries() {
        let yaml = r#"
version: "1.0.0"
target_crate: aprender
bindings:
  - contract: softmax-kernel-v1.yaml
    equation: softmax
    module_path: "aprender::nn::functional::softmax"
    function: softmax
    signature: "fn softmax(x: &Tensor, dim: i32) -> Tensor"
    status: implemented
  - contract: activation-kernel-v1.yaml
    equation: silu
    status: not_implemented
    notes: "Not yet available"
"#;
        let reg = parse_binding_str(yaml).unwrap();
        assert_eq!(reg.bindings.len(), 2);
        assert_eq!(reg.bindings[0].equation, "softmax");
        assert_eq!(reg.bindings[0].status, ImplStatus::Implemented);
        assert!(reg.bindings[0].module_path.is_some());
        assert_eq!(reg.bindings[1].equation, "silu");
        assert_eq!(reg.bindings[1].status, ImplStatus::NotImplemented);
        assert!(reg.bindings[1].module_path.is_none());
    }

    #[test]
    fn parse_partial_status() {
        let yaml = r#"
version: "1.0.0"
target_crate: test
bindings:
  - contract: test.yaml
    equation: f
    module_path: "test::f"
    function: f
    status: partial
    notes: "Only scalar path"
"#;
        let reg = parse_binding_str(yaml).unwrap();
        assert_eq!(reg.bindings[0].status, ImplStatus::Partial);
    }

    #[test]
    fn parse_rejects_unknown_status() {
        let yaml = r#"
version: "1.0.0"
target_crate: test
bindings:
  - contract: test.yaml
    equation: f
    status: done
"#;
        assert!(parse_binding_str(yaml).is_err());
    }

    #[test]
    fn parse_rejects_missing_required_field() {
        // `version` has no serde default, so omitting it must fail.
        assert!(parse_binding_str("target_crate: aprender\n").is_err());
    }

    #[test]
    fn impl_status_display() {
        assert_eq!(ImplStatus::Implemented.to_string(), "implemented");
        assert_eq!(ImplStatus::Partial.to_string(), "partial");
        assert_eq!(ImplStatus::NotImplemented.to_string(), "not_implemented");
        assert_eq!(ImplStatus::Pending.to_string(), "pending");
    }

    #[test]
    fn parse_invalid_binding_yaml() {
        let result = parse_binding_str("not: [valid: {{");
        assert!(result.is_err());
    }

    #[test]
    fn normalize_contract_id_strips_yaml_extensions() {
        assert_eq!(normalize_contract_id("softmax-kernel-v1.yaml"), "softmax-kernel-v1");
        assert_eq!(normalize_contract_id("softmax-kernel-v1.yml"), "softmax-kernel-v1");
        assert_eq!(normalize_contract_id("softmax-kernel-v1"), "softmax-kernel-v1");
    }

    #[test]
    fn bindings_for_ignores_yaml_extension() {
        let reg = lookup_registry();
        let equations: Vec<&str> = reg
            .bindings_for("softmax-kernel-v1")
            .into_iter()
            .map(|b| b.equation.as_str())
            .collect();
        assert_eq!(equations, ["softmax", "log_softmax"]);
        assert_eq!(reg.bindings_for("activation-kernel-v1.yaml").len(), 1);
        assert!(reg.bindings_for("missing-v1.yaml").is_empty());
    }

    #[test]
    fn find_binding_matches_contract_and_equation() {
        let reg = lookup_registry();
        let found = reg
            .find_binding("softmax-kernel-v1.yaml", "log_softmax")
            .unwrap();
        assert_eq!(found.status, ImplStatus::Partial);
        assert!(reg.find_binding("softmax-kernel-v1", "silu").is_none());
        assert_eq!(
            reg.find_binding("activation-kernel-v1.yml", "silu")
                .map(|b| b.status),
            Some(ImplStatus::Pending)
        );
    }

    #[test]
    fn parse_binding_from_file() {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../../contracts/aprender/binding.yaml");
        let reg = parse_binding(&path).unwrap();
        assert_eq!(reg.target_crate, "aprender");
        assert!(!reg.bindings.is_empty());
    }

    #[test]
    fn parse_binding_nonexistent_file() {
        let result = parse_binding(std::path::Path::new("/nonexistent/binding.yaml"));
        assert!(result.is_err());
    }
}