provable_contracts/
pipeline.rs1use std::path::Path;
10
11use serde::{Deserialize, Serialize};
12
13use crate::error::ContractError;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct PipelineContract {
18 pub metadata: PipelineMetadata,
19 #[serde(default)]
20 pub stages: Vec<PipelineStage>,
21 #[serde(default)]
22 pub cross_boundary_obligations: Vec<CrossBoundaryObligation>,
23 #[serde(default)]
24 pub performance_contract: Option<PerformanceContract>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PipelineMetadata {
30 pub version: String,
31 #[serde(default)]
32 pub created: Option<String>,
33 #[serde(default)]
34 pub author: Option<String>,
35 pub description: String,
36 #[serde(default)]
37 pub pipeline: bool,
38 #[serde(default)]
39 pub references: Vec<String>,
40 #[serde(default)]
41 pub depends_on: Vec<String>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct PipelineStage {
47 pub name: String,
48 #[serde(default)]
49 pub repo: Option<String>,
50 #[serde(default)]
51 pub contract: Option<String>,
52 #[serde(default)]
53 pub equation: Option<String>,
54 #[serde(default)]
55 pub input_requires: Option<String>,
56 #[serde(default)]
57 pub output_invariant: Option<String>,
58 #[serde(default)]
59 pub output_shape: Option<String>,
60 #[serde(default)]
61 pub repeat: Option<String>,
62 #[serde(default)]
63 pub substages: Vec<PipelineStage>,
64 #[serde(default)]
65 pub depends_on_contracts: Vec<String>,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct CrossBoundaryObligation {
71 pub id: String,
72 pub property: String,
73 pub from_stage: String,
74 pub to_stage: String,
75 pub formal: String,
76 #[serde(default)]
77 pub verification: Option<String>,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct PerformanceContract {
83 #[serde(default)]
84 pub roofline: Option<String>,
85 #[serde(default)]
86 pub prefill_bound: Option<String>,
87 #[serde(default)]
88 pub decode_bound: Option<String>,
89 #[serde(default)]
90 pub throughput_ceiling: Option<String>,
91}
92
93pub fn parse_pipeline(path: &Path) -> Result<PipelineContract, ContractError> {
95 let content = std::fs::read_to_string(path)?;
96 parse_pipeline_str(&content)
97}
98
99pub fn parse_pipeline_str(yaml: &str) -> Result<PipelineContract, ContractError> {
101 let pipeline: PipelineContract = serde_yaml::from_str(yaml)?;
102 Ok(pipeline)
103}
104
105pub fn validate_pipeline(pipeline: &PipelineContract) -> Vec<PipelineIssue> {
112 let mut issues = Vec::new();
113
114 let stage_names: Vec<String> = collect_stage_names(&pipeline.stages);
116
117 let mut seen = std::collections::HashSet::new();
119 for name in &stage_names {
120 if !seen.insert(name.as_str()) {
121 issues.push(PipelineIssue {
122 severity: IssueSeverity::Warning,
123 message: format!("Duplicate stage name: {name}"),
124 });
125 }
126 }
127
128 for ob in &pipeline.cross_boundary_obligations {
130 if !stage_names.contains(&ob.from_stage) {
131 issues.push(PipelineIssue {
132 severity: IssueSeverity::Error,
133 message: format!(
134 "Obligation {} references unknown from_stage: {}",
135 ob.id, ob.from_stage
136 ),
137 });
138 }
139 if !stage_names.contains(&ob.to_stage) {
140 issues.push(PipelineIssue {
141 severity: IssueSeverity::Error,
142 message: format!(
143 "Obligation {} references unknown to_stage: {}",
144 ob.id, ob.to_stage
145 ),
146 });
147 }
148 }
149
150 for window in pipeline.stages.windows(2) {
152 let prev = &window[0];
153 let next = &window[1];
154 if prev.output_invariant.is_none() {
155 issues.push(PipelineIssue {
156 severity: IssueSeverity::Warning,
157 message: format!("Stage '{}' has no output_invariant", prev.name),
158 });
159 }
160 if next.input_requires.is_none() {
161 issues.push(PipelineIssue {
162 severity: IssueSeverity::Warning,
163 message: format!("Stage '{}' has no input_requires", next.name),
164 });
165 }
166 }
167
168 issues
169}
170
171fn collect_stage_names(stages: &[PipelineStage]) -> Vec<String> {
173 let mut names = Vec::new();
174 for stage in stages {
175 names.push(stage.name.clone());
176 if !stage.substages.is_empty() {
177 names.extend(collect_stage_names(&stage.substages));
178 }
179 }
180 names
181}
182
183#[derive(Debug, Clone)]
185pub struct PipelineIssue {
186 pub severity: IssueSeverity,
187 pub message: String,
188}
189
/// How serious a validation finding is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IssueSeverity {
    /// The contract is inconsistent, e.g. an obligation names an unknown stage.
    Error,
    /// Advisory only, e.g. a duplicate name or a missing boundary condition.
    Warning,
}
196
197impl std::fmt::Display for IssueSeverity {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 match self {
200 Self::Error => write!(f, "ERROR"),
201 Self::Warning => write!(f, "WARN"),
202 }
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Keep only the `Error`-severity issues from a validation run.
    fn errors_of(issues: &[PipelineIssue]) -> Vec<&PipelineIssue> {
        issues
            .iter()
            .filter(|i| i.severity == IssueSeverity::Error)
            .collect()
    }

    #[test]
    fn parse_minimal_pipeline() {
        let yaml = r#"
metadata:
  version: "1.0.0"
  description: "Test pipeline"
  pipeline: true
stages:
  - name: stage_a
    repo: repo_a
    contract: contract-v1
    equation: eq_a
    output_invariant: "x > 0"
  - name: stage_b
    repo: repo_b
    contract: contract-v1
    equation: eq_b
    input_requires: "x > 0"
cross_boundary_obligations:
  - id: PIPE-001
    property: "A feeds B"
    from_stage: stage_a
    to_stage: stage_b
    formal: "output(a) satisfies input(b)"
"#;
        let pipeline = parse_pipeline_str(yaml).unwrap();
        assert_eq!(pipeline.stages.len(), 2);
        assert_eq!(pipeline.cross_boundary_obligations.len(), 1);
        assert!(pipeline.metadata.pipeline);
    }

    #[test]
    fn validate_valid_pipeline() {
        let yaml = r#"
metadata:
  version: "1.0.0"
  description: "Test"
  pipeline: true
stages:
  - name: a
    output_invariant: "x > 0"
  - name: b
    input_requires: "x > 0"
    output_invariant: "y > 0"
cross_boundary_obligations:
  - id: P1
    property: "a→b"
    from_stage: a
    to_stage: b
    formal: "ok"
"#;
        let pipeline = parse_pipeline_str(yaml).unwrap();
        let issues = validate_pipeline(&pipeline);
        // Both boundary conditions are present and all names are unique,
        // so not even a warning is expected.
        assert!(issues.is_empty(), "unexpected issues: {issues:?}");
    }

    #[test]
    fn validate_bad_stage_ref() {
        let yaml = r#"
metadata:
  version: "1.0.0"
  description: "Test"
  pipeline: true
stages:
  - name: a
cross_boundary_obligations:
  - id: P1
    property: "bad ref"
    from_stage: a
    to_stage: nonexistent
    formal: "fail"
"#;
        let pipeline = parse_pipeline_str(yaml).unwrap();
        let issues = validate_pipeline(&pipeline);
        let errors = errors_of(&issues);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("nonexistent"));
    }

    #[test]
    fn parse_inference_forward() {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../../contracts/pipelines/inference-forward-v1.yaml");
        // The fixture lives outside this crate; skip quietly when it is absent.
        if path.exists() {
            let pipeline = parse_pipeline(&path).unwrap();
            assert!(pipeline.metadata.pipeline);
            assert!(!pipeline.stages.is_empty());
            assert!(!pipeline.cross_boundary_obligations.is_empty());
            let issues = validate_pipeline(&pipeline);
            let errors = errors_of(&issues);
            assert!(errors.is_empty(), "Errors: {errors:?}");
        }
    }

    #[test]
    fn substage_names_collected() {
        let yaml = r#"
metadata:
  version: "1.0.0"
  description: "Test"
  pipeline: true
stages:
  - name: outer
    substages:
      - name: inner_a
      - name: inner_b
cross_boundary_obligations:
  - id: P1
    property: "inner ref"
    from_stage: inner_a
    to_stage: inner_b
    formal: "ok"
"#;
        let pipeline = parse_pipeline_str(yaml).unwrap();
        let issues = validate_pipeline(&pipeline);
        assert!(errors_of(&issues).is_empty());
    }

    #[test]
    fn duplicate_stage_name_warns() {
        // The substage reuses its parent's name; nested names count too.
        let yaml = r#"
metadata:
  version: "1.0.0"
  description: "Test"
stages:
  - name: a
    output_invariant: "x > 0"
    substages:
      - name: a
  - name: b
    input_requires: "x > 0"
"#;
        let pipeline = parse_pipeline_str(yaml).unwrap();
        let issues = validate_pipeline(&pipeline);
        assert_eq!(issues.len(), 1, "issues: {issues:?}");
        assert_eq!(issues[0].severity, IssueSeverity::Warning);
        assert_eq!(issues[0].message, "Duplicate stage name: a");
    }

    #[test]
    fn missing_boundary_conditions_warn() {
        let yaml = r#"
metadata:
  version: "1.0.0"
  description: "Test"
stages:
  - name: a
  - name: b
"#;
        let pipeline = parse_pipeline_str(yaml).unwrap();
        let issues = validate_pipeline(&pipeline);
        assert!(errors_of(&issues).is_empty());
        let messages: Vec<&str> = issues.iter().map(|i| i.message.as_str()).collect();
        assert_eq!(
            messages,
            vec![
                "Stage 'a' has no output_invariant",
                "Stage 'b' has no input_requires",
            ]
        );
        assert!(issues
            .iter()
            .all(|i| i.severity == IssueSeverity::Warning));
    }
}