Skip to main content

batuta/serve/banco/
recipes.rs

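//! Data recipe engine — declarative pipelines that transform files into datasets.
//!
//! A recipe is an ordered sequence of steps, for example
//! extract_text → chunk → filter → format. The available step types are
//! extract_text, parse_csv, parse_jsonl, chunk, filter, format, and deduplicate.
//! Each step transforms the records produced by the previous one; the final
//! records form a dataset suitable for training.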
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10/// A data recipe definition.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Recipe {
13    pub id: String,
14    pub name: String,
15    pub source_files: Vec<String>,
16    pub steps: Vec<RecipeStep>,
17    pub output_format: String,
18    pub created_at: u64,
19    pub status: RecipeStatus,
20}
21
22/// Recipe execution status.
23#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
24#[serde(rename_all = "snake_case")]
25pub enum RecipeStatus {
26    Created,
27    Running,
28    Complete,
29    Failed,
30}
31
32/// A single step in a recipe pipeline.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct RecipeStep {
35    #[serde(rename = "type")]
36    pub step_type: StepType,
37    #[serde(default)]
38    pub config: serde_json::Value,
39}
40
41/// Built-in recipe step types.
42#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
43#[serde(rename_all = "snake_case")]
44pub enum StepType {
45    ExtractText,
46    ParseCsv,
47    ParseJsonl,
48    Chunk,
49    Filter,
50    Format,
51    Deduplicate,
52}
53
54/// A single record flowing through the pipeline.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct Record {
57    pub text: String,
58    #[serde(default)]
59    pub metadata: HashMap<String, String>,
60}
61
62/// Result of running a recipe.
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct DatasetResult {
65    pub dataset_id: String,
66    pub recipe_id: String,
67    pub record_count: usize,
68    pub records: Vec<Record>,
69}
70
71/// Recipe store — manages recipe definitions and execution.
72pub struct RecipeStore {
73    recipes: RwLock<HashMap<String, Recipe>>,
74    datasets: RwLock<HashMap<String, DatasetResult>>,
75    counter: std::sync::atomic::AtomicU64,
76}
77
78impl RecipeStore {
79    #[must_use]
80    pub fn new() -> Arc<Self> {
81        Arc::new(Self {
82            recipes: RwLock::new(HashMap::new()),
83            datasets: RwLock::new(HashMap::new()),
84            counter: std::sync::atomic::AtomicU64::new(0),
85        })
86    }
87
88    /// Create a recipe definition.
89    pub fn create(
90        &self,
91        name: &str,
92        source_files: Vec<String>,
93        steps: Vec<RecipeStep>,
94        output_format: &str,
95    ) -> Recipe {
96        let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
97        let recipe = Recipe {
98            id: format!("recipe-{}-{seq}", epoch_secs()),
99            name: name.to_string(),
100            source_files,
101            steps,
102            output_format: output_format.to_string(),
103            created_at: epoch_secs(),
104            status: RecipeStatus::Created,
105        };
106        if let Ok(mut store) = self.recipes.write() {
107            store.insert(recipe.id.clone(), recipe.clone());
108        }
109        recipe
110    }
111
112    /// List all recipes.
113    #[must_use]
114    pub fn list(&self) -> Vec<Recipe> {
115        let store = self.recipes.read().unwrap_or_else(|e| e.into_inner());
116        let mut recipes: Vec<Recipe> = store.values().cloned().collect();
117        recipes.sort_by(|a, b| b.created_at.cmp(&a.created_at));
118        recipes
119    }
120
121    /// Get a recipe by ID.
122    #[must_use]
123    pub fn get(&self, id: &str) -> Option<Recipe> {
124        self.recipes.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
125    }
126
127    /// Run a recipe against source texts, producing a dataset.
128    pub fn run(
129        &self,
130        recipe_id: &str,
131        source_texts: &[(&str, &str)],
132    ) -> Result<DatasetResult, RecipeError> {
133        let recipe = self.get(recipe_id).ok_or(RecipeError::NotFound(recipe_id.to_string()))?;
134
135        // Update status
136        if let Ok(mut store) = self.recipes.write() {
137            if let Some(r) = store.get_mut(recipe_id) {
138                r.status = RecipeStatus::Running;
139            }
140        }
141
142        // Initialize records from source texts
143        let mut records: Vec<Record> = source_texts
144            .iter()
145            .map(|(name, text)| Record {
146                text: text.to_string(),
147                metadata: [("source".to_string(), name.to_string())].into(),
148            })
149            .collect();
150
151        // Execute pipeline steps
152        for step in &recipe.steps {
153            records = execute_step(step, records)?;
154        }
155
156        let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
157        let result = DatasetResult {
158            dataset_id: format!("ds-{}-{seq}", epoch_secs()),
159            recipe_id: recipe_id.to_string(),
160            record_count: records.len(),
161            records,
162        };
163
164        // Store dataset
165        if let Ok(mut store) = self.datasets.write() {
166            store.insert(result.dataset_id.clone(), result.clone());
167        }
168
169        // Update status
170        if let Ok(mut store) = self.recipes.write() {
171            if let Some(r) = store.get_mut(recipe_id) {
172                r.status = RecipeStatus::Complete;
173            }
174        }
175
176        Ok(result)
177    }
178
179    /// Get a dataset by ID.
180    #[must_use]
181    pub fn get_dataset(&self, id: &str) -> Option<DatasetResult> {
182        self.datasets.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
183    }
184
185    /// List all datasets.
186    #[must_use]
187    pub fn list_datasets(&self) -> Vec<DatasetResult> {
188        let store = self.datasets.read().unwrap_or_else(|e| e.into_inner());
189        store.values().cloned().collect()
190    }
191}
192
193/// Execute a single pipeline step on a set of records.
194fn execute_step(step: &RecipeStep, records: Vec<Record>) -> Result<Vec<Record>, RecipeError> {
195    match step.step_type {
196        StepType::ExtractText => Ok(records), // Already text — passthrough
197        StepType::ParseCsv => execute_parse_csv(step, records),
198        StepType::ParseJsonl => execute_parse_jsonl(step, records),
199        StepType::Chunk => execute_chunk(step, records),
200        StepType::Filter => execute_filter(step, records),
201        StepType::Format => execute_format(step, records),
202        StepType::Deduplicate => execute_dedup(records),
203    }
204}
205
206/// Parse CSV: convert CSV records into text records using alimentar (or simple fallback).
207fn execute_parse_csv(step: &RecipeStep, records: Vec<Record>) -> Result<Vec<Record>, RecipeError> {
208    let text_column = step.config.get("text_column").and_then(|v| v.as_str());
209    let delimiter = step
210        .config
211        .get("delimiter")
212        .and_then(|v| v.as_str())
213        .and_then(|s| s.as_bytes().first().copied())
214        .unwrap_or(b',');
215
216    let mut output = Vec::new();
217
218    for record in &records {
219        let parsed = parse_csv_content(&record.text, text_column, delimiter);
220        for (i, text) in parsed.into_iter().enumerate() {
221            let mut meta = record.metadata.clone();
222            meta.insert("row_index".to_string(), i.to_string());
223            output.push(Record { text, metadata: meta });
224        }
225    }
226
227    Ok(output)
228}
229
/// Parse CSV content, first validating it with alimentar.
///
/// Only compiled with the `alimentar` cargo feature. alimentar's Arrow CSV
/// reader is used purely as a validator: a parse failure, or a `text_column`
/// missing from the inferred schema, is reported on stderr but never aborts
/// the step. Rows are always extracted by `parse_csv_fallback`, so the output
/// matches the build without the feature.
///
/// `ArrowDataset::from_csv_str` takes no delimiter argument, so validation is
/// only attempted for comma-delimited input. For other delimiters the reader
/// would see the whole header as one column and warn spuriously.
/// NOTE(review): assumes alimentar's default CSV delimiter is `,` — confirm.
#[cfg(feature = "alimentar")]
fn parse_csv_content(csv_text: &str, text_column: Option<&str>, delimiter: u8) -> Vec<String> {
    if delimiter == b',' {
        warn_on_invalid_csv(csv_text, text_column);
    }

    // Use the simple line-based extractor (works for all delimiters)
    parse_csv_fallback(csv_text, text_column, delimiter)
}

/// Best-effort alimentar validation of comma-delimited CSV; only prints
/// `[banco]` warnings to stderr.
#[cfg(feature = "alimentar")]
fn warn_on_invalid_csv(csv_text: &str, text_column: Option<&str>) {
    use alimentar::{ArrowDataset, Dataset};

    let ds = match ArrowDataset::from_csv_str(csv_text) {
        Ok(ds) => ds,
        Err(e) => {
            eprintln!("[banco] CSV parse warning: {e}");
            return;
        }
    };

    let schema = ds.schema();
    if let Some(col) = text_column {
        if !schema.fields().iter().any(|f| f.name() == col) {
            eprintln!(
                "[banco] Warning: column '{}' not found in CSV (available: {})",
                col,
                schema
                    .fields()
                    .iter()
                    .map(|f| f.name().as_str())
                    .collect::<Vec<_>>()
                    .join(", ")
            );
        }
    }
}
266
267/// Fallback CSV parsing without alimentar.
268#[cfg(not(feature = "alimentar"))]
269fn parse_csv_content(csv_text: &str, text_column: Option<&str>, delimiter: u8) -> Vec<String> {
270    parse_csv_fallback(csv_text, text_column, delimiter)
271}
272
/// Simple line-based CSV parser (no Arrow).
///
/// The first line is the header; a leading UTF-8 byte-order mark is ignored.
/// Every later non-blank line is split into trimmed fields by
/// `split_csv_fields`, which understands RFC 4180-style double-quoted
/// fields on a single line. A quoted field may therefore contain the
/// delimiter, and `""` inside quotes is a literal quote. Quoted fields that
/// span several lines are NOT supported.
///
/// * If `text_column` matches a header field, that column's value is returned
///   for each row. Rows lacking the column yield `""` and are dropped.
/// * Otherwise (no column requested, or name not in the header), all fields
///   of the row are joined with `" | "`.
///
/// Empty results are filtered out.
fn parse_csv_fallback(csv_text: &str, text_column: Option<&str>, delimiter: u8) -> Vec<String> {
    let delim = char::from(delimiter);
    let mut lines = csv_text.lines();
    let header = match lines.next() {
        // Strip a UTF-8 BOM so the first column name can still match.
        Some(h) => h.trim_start_matches('\u{feff}'),
        None => return Vec::new(),
    };

    let col_idx = text_column.and_then(|name| {
        split_csv_fields(header, delim).iter().position(|field| field.as_str() == name)
    });

    lines
        .filter(|line| !line.trim().is_empty())
        .map(|line| {
            let fields = split_csv_fields(line, delim);
            match col_idx {
                Some(idx) => fields.into_iter().nth(idx).unwrap_or_default(),
                None => fields.join(" | "),
            }
        })
        .filter(|s| !s.is_empty())
        .collect()
}

/// Split one CSV line into whitespace-trimmed fields, honouring double quotes.
///
/// A `"` opens a quoted field only at the start of a field (leading
/// whitespace allowed). Inside quotes the delimiter is literal and `""` is an
/// escaped quote. A `"` elsewhere is kept as an ordinary character. For lines
/// without quotes this is exactly `line.split(delim)` with each part trimmed.
fn split_csv_fields(line: &str, delim: char) -> Vec<String> {
    let mut fields = Vec::new();
    let mut field = String::new();
    let mut in_quotes = false;
    let mut chars = line.chars().peekable();

    while let Some(c) = chars.next() {
        if in_quotes {
            if c == '"' {
                if chars.peek() == Some(&'"') {
                    chars.next();
                    field.push('"');
                } else {
                    in_quotes = false;
                }
            } else {
                field.push(c);
            }
        } else if c == delim {
            fields.push(field.trim().to_string());
            field.clear();
        } else if c == '"' && field.trim().is_empty() {
            field.clear();
            in_quotes = true;
        } else {
            field.push(c);
        }
    }
    fields.push(field.trim().to_string());
    fields
}
296
297/// Parse JSONL: convert JSON lines into text records.
298fn execute_parse_jsonl(
299    step: &RecipeStep,
300    records: Vec<Record>,
301) -> Result<Vec<Record>, RecipeError> {
302    let text_field = step.config.get("text_field").and_then(|v| v.as_str());
303
304    let mut output = Vec::new();
305    for record in &records {
306        for (i, line) in record.text.lines().enumerate() {
307            let line = line.trim();
308            if line.is_empty() {
309                continue;
310            }
311            let text = if let Ok(obj) = serde_json::from_str::<serde_json::Value>(line) {
312                if let Some(field) = text_field {
313                    obj.get(field).and_then(|v| v.as_str()).unwrap_or("").to_string()
314                } else {
315                    // Use first string field, or stringify the whole object
316                    obj.as_object()
317                        .and_then(|o| o.values().find_map(|v| v.as_str().map(String::from)))
318                        .unwrap_or_else(|| obj.to_string())
319                }
320            } else {
321                line.to_string()
322            };
323            if !text.is_empty() {
324                let mut meta = record.metadata.clone();
325                meta.insert("line_index".to_string(), i.to_string());
326                output.push(Record { text, metadata: meta });
327            }
328        }
329    }
330    Ok(output)
331}
332
333/// Chunk: split text into token-aware pieces.
334fn execute_chunk(step: &RecipeStep, records: Vec<Record>) -> Result<Vec<Record>, RecipeError> {
335    let max_chars =
336        step.config.get("max_tokens").and_then(|v| v.as_u64()).unwrap_or(512) as usize * 4;
337    let raw_overlap =
338        step.config.get("overlap").and_then(|v| v.as_u64()).unwrap_or(64) as usize * 4;
339    // Clamp overlap to half the chunk size to prevent subtraction overflow
340    let overlap = raw_overlap.min(max_chars / 2);
341
342    let mut chunks = Vec::new();
343    for record in &records {
344        let text = &record.text;
345        if text.len() <= max_chars {
346            chunks.push(record.clone());
347            continue;
348        }
349
350        let mut start = 0;
351        let mut chunk_idx = 0;
352        while start < text.len() {
353            let end = (start + max_chars).min(text.len());
354            let chunk_text = &text[start..end];
355            let mut meta = record.metadata.clone();
356            meta.insert("chunk_index".to_string(), chunk_idx.to_string());
357            chunks.push(Record { text: chunk_text.to_string(), metadata: meta });
358            start = if end == text.len() { end } else { end - overlap };
359            chunk_idx += 1;
360        }
361    }
362    Ok(chunks)
363}
364
365/// Filter: remove records by min/max length.
366fn execute_filter(step: &RecipeStep, records: Vec<Record>) -> Result<Vec<Record>, RecipeError> {
367    let min_len = step.config.get("min_length").and_then(|v| v.as_u64()).unwrap_or(1) as usize;
368    let max_len =
369        step.config.get("max_length").and_then(|v| v.as_u64()).unwrap_or(u64::MAX) as usize;
370
371    Ok(records.into_iter().filter(|r| r.text.len() >= min_len && r.text.len() <= max_len).collect())
372}
373
374/// Format: apply a chat template to records.
375fn execute_format(step: &RecipeStep, records: Vec<Record>) -> Result<Vec<Record>, RecipeError> {
376    let template = step.config.get("template").and_then(|v| v.as_str()).unwrap_or("chatml");
377
378    Ok(records
379        .into_iter()
380        .map(|r| {
381            let formatted = match template {
382                "chatml" => {
383                    format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", r.text)
384                }
385                "alpaca" => format!("### Instruction:\n{}\n\n### Response:\n", r.text),
386                "llama2" => format!("[INST] {} [/INST]", r.text),
387                _ => r.text.clone(),
388            };
389            let mut meta = r.metadata;
390            meta.insert("template".to_string(), template.to_string());
391            Record { text: formatted, metadata: meta }
392        })
393        .collect())
394}
395
396/// Deduplicate: remove exact duplicate texts.
397fn execute_dedup(records: Vec<Record>) -> Result<Vec<Record>, RecipeError> {
398    let mut seen = std::collections::HashSet::new();
399    Ok(records.into_iter().filter(|r| seen.insert(r.text.clone())).collect())
400}
401
/// Errors returned by recipe operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecipeError {
    /// The requested recipe ID is not registered in the store.
    NotFound(String),
}

impl std::fmt::Display for RecipeError {
    /// Renders a human-readable message such as `Recipe not found: <id>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotFound(id) => {
                f.write_str("Recipe not found: ")?;
                f.write_str(id)
            }
        }
    }
}

impl std::error::Error for RecipeError {}
417
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}