Skip to main content

noether_engine/executor/
runtime.rs

1//! Runtime executor: handles stages that need external dependencies —
2//! an LLM provider, the stage store, or the semantic index.
3//!
4//! ## Stages handled
5//!
6//! | Stage description                                        | Needs       |
7//! |----------------------------------------------------------|-------------|
8//! | Generate text completion using a language model          | LLM         |
9//! | Generate a vector embedding for text                     | Embedding   |
10//! | Classify text into one of the provided categories        | LLM         |
11//! | Extract structured data from text according to a schema  | LLM         |
12//! | Get detailed information about a stage by its ID         | store cache |
13//! | Search the stage store by semantic query                 | store cache + optional embeddings |
14//! | Check if one type is a structural subtype of another     | pure        |
15//! | Verify that a composition graph type-checks correctly    | store cache |
16
17use super::{ExecutionError, StageExecutor};
18use noether_core::stage::StageId;
19use noether_core::types::NType;
20use noether_store::StageStore;
21use serde_json::{json, Value};
22use std::collections::HashMap;
23use std::sync::Mutex;
24
25use crate::index::embedding::EmbeddingProvider;
26use crate::llm::{LlmConfig, LlmProvider, Message};
27
28// ── Cached stage info (built once at construction) ────────────────────────────
29
/// Snapshot of one stage's metadata, flattened to display strings once at
/// construction so describe/search never have to re-query the store.
#[derive(Clone)]
struct CachedStage {
    /// Stage ID (the inner `StageId.0` string).
    id: String,
    /// Human-readable stage description (also used as dispatch/search text).
    description: String,
    /// `Display` rendering of the signature's input type.
    input_display: String,
    /// `Display` rendering of the signature's output type.
    output_display: String,
    /// Lowercased `Debug` rendering of the lifecycle (e.g. "deprecated", "tombstone").
    lifecycle: String,
    /// `Debug` renderings of the signature's effects.
    effects: Vec<String>,
    /// Number of usage examples attached to the stage.
    examples_count: usize,
}
40
41// ── Cosine similarity ─────────────────────────────────────────────────────────
42
/// Cosine similarity of two vectors, in [-1, 1].
///
/// Returns 0.0 for mismatched lengths, empty inputs, or when either vector
/// has zero magnitude, so the result is always a well-defined finite value.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Accumulate dot product and both squared norms in a single pass.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, xa, xb), (&x, &y)| {
            (d + x * y, xa + x * x, xb + y * y)
        });
    let (na, nb) = (sq_a.sqrt(), sq_b.sqrt());
    if na == 0.0 || nb == 0.0 {
        0.0
    } else {
        dot / (na * nb)
    }
}
56
57// ── RuntimeExecutor ───────────────────────────────────────────────────────────
58
/// Executor for stages that need external dependencies: an LLM provider,
/// an embedding provider, or a read-only snapshot of the stage store.
pub struct RuntimeExecutor {
    /// Optional LLM backend; LLM-backed stages fail with a clear message when absent.
    llm: Option<Box<dyn LlmProvider>>,
    /// Default generation parameters, overridable per-call via stage input.
    llm_config: LlmConfig,
    /// Optional embedding backend for real semantic search and llm_embed.
    embedding_provider: Option<Box<dyn EmbeddingProvider>>,
    /// stage_id → description (for dispatch)
    descriptions: HashMap<String, String>,
    /// Flattened stage list for search and describe
    stage_cache: Vec<CachedStage>,
    /// Pre-computed embeddings per stage ID: populated when with_embedding() is called.
    stage_embeddings: HashMap<String, Vec<f32>>,
    /// Session-scoped LLM call deduplication, keyed by a SHA-256 digest of the
    /// request parameters → cached JSON response.
    llm_dedup_cache: Mutex<HashMap<String, Value>>,
}
72
73impl RuntimeExecutor {
74    /// Build from a store. LLM and embedding providers are not required; stages that
75    /// need them will return `ExecutionError::StageFailed` with a clear message.
76    pub fn from_store(store: &dyn StageStore) -> Self {
77        let mut descriptions = HashMap::new();
78        let mut stage_cache = Vec::new();
79
80        for stage in store.list(None) {
81            descriptions.insert(stage.id.0.clone(), stage.description.clone());
82
83            let effects: Vec<String> = stage
84                .signature
85                .effects
86                .iter()
87                .map(|e| format!("{e:?}"))
88                .collect();
89
90            stage_cache.push(CachedStage {
91                id: stage.id.0.clone(),
92                description: stage.description.clone(),
93                input_display: format!("{}", stage.signature.input),
94                output_display: format!("{}", stage.signature.output),
95                lifecycle: format!("{:?}", stage.lifecycle).to_lowercase(),
96                effects,
97                examples_count: stage.examples.len(),
98            });
99        }
100
101        Self {
102            llm: None,
103            llm_config: LlmConfig::default(),
104            embedding_provider: None,
105            descriptions,
106            stage_cache,
107            stage_embeddings: HashMap::new(),
108            llm_dedup_cache: Mutex::new(HashMap::new()),
109        }
110    }
111
112    /// Attach an LLM provider, enabling llm_complete/llm_classify/llm_extract stages.
113    pub fn with_llm(mut self, llm: Box<dyn LlmProvider>, config: LlmConfig) -> Self {
114        self.llm = Some(llm);
115        self.llm_config = config;
116        self
117    }
118
119    /// Attach an embedding provider, enabling real cosine-similarity store_search
120    /// and real llm_embed responses. Pre-computes embeddings for all cached stages.
121    pub fn with_embedding(mut self, provider: Box<dyn EmbeddingProvider>) -> Self {
122        // Pre-compute embeddings for all stage descriptions
123        let mut embeddings = HashMap::new();
124        for stage in &self.stage_cache {
125            if let Ok(emb) = provider.embed(&stage.description) {
126                embeddings.insert(stage.id.clone(), emb);
127            }
128        }
129        self.stage_embeddings = embeddings;
130        self.embedding_provider = Some(provider);
131        self
132    }
133
134    /// Set or replace the LLM provider in-place.
135    pub fn set_llm(&mut self, llm: Box<dyn LlmProvider>, config: LlmConfig) {
136        self.llm = Some(llm);
137        self.llm_config = config;
138    }
139
140    /// True if this executor can handle the given stage.
141    pub fn has_implementation(&self, stage_id: &StageId) -> bool {
142        matches!(
143            self.descriptions.get(&stage_id.0).map(|s| s.as_str()),
144            Some(
145                "Generate text completion using a language model"
146                    | "Generate a vector embedding for text"
147                    | "Classify text into one of the provided categories"
148                    | "Extract structured data from text according to a schema"
149                    | "Get detailed information about a stage by its ID"
150                    | "Search the stage store by semantic query"
151                    | "Check if one type is a structural subtype of another"
152                    | "Verify that a composition graph type-checks correctly"
153                    | "Register a new stage in the store"
154                    | "Retrieve the execution trace of a past composition"
155            )
156        )
157    }
158
159    // ── Dispatch ──────────────────────────────────────────────────────────────
160
    /// Route a stage execution to its handler by looking up the stage's
    /// *description* (not its ID) in the cached map. Unknown stages return
    /// `StageNotFound`; known-but-unrunnable stages (store_add, trace_read)
    /// return a targeted `StageFailed` explaining the alternative.
    fn dispatch(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
        // Unknown IDs map to "" which falls through to the `_` arm below.
        let desc = self
            .descriptions
            .get(&stage_id.0)
            .map(|s| s.as_str())
            .unwrap_or("");

        match desc {
            "Generate text completion using a language model" => self.llm_complete(stage_id, input),
            "Generate a vector embedding for text" => self.llm_embed(stage_id, input),
            "Classify text into one of the provided categories" => {
                self.llm_classify(stage_id, input)
            }
            "Extract structured data from text according to a schema" => {
                self.llm_extract(stage_id, input)
            }
            "Get detailed information about a stage by its ID" => {
                self.stage_describe(stage_id, input)
            }
            "Search the stage store by semantic query" => self.store_search(stage_id, input),
            "Check if one type is a structural subtype of another" => type_check(stage_id, input),
            "Verify that a composition graph type-checks correctly" => {
                self.composition_verify(stage_id, input)
            }
            "Register a new stage in the store" => {
                // store_add requires mutable store access which executors don't hold.
                // Use `noether compose` or the synthesis API to register new stages.
                Err(ExecutionError::StageFailed {
                    stage_id: stage_id.clone(),
                    message: "store_add cannot be called inside a composition graph — use `noether compose` or the synthesis API to register new stages".into(),
                })
            }
            "Retrieve the execution trace of a past composition" => {
                // trace_read requires the TraceStore which executors don't hold.
                // Use `noether trace <composition_id>` from the CLI.
                Err(ExecutionError::StageFailed {
                    stage_id: stage_id.clone(),
                    message: "trace_read cannot be called inside a composition graph — use `noether trace <composition_id>` from the CLI".into(),
                })
            }
            _ => Err(ExecutionError::StageNotFound(stage_id.clone())),
        }
    }
204
205    // ── LLM helpers ───────────────────────────────────────────────────────────
206
207    fn require_llm(&self, stage_id: &StageId) -> Result<&dyn LlmProvider, ExecutionError> {
208        self.llm.as_deref().ok_or_else(|| ExecutionError::StageFailed {
209            stage_id: stage_id.clone(),
210            message: "LLM provider not configured (set VERTEX_AI_PROJECT, VERTEX_AI_TOKEN, VERTEX_AI_LOCATION)".into(),
211        })
212    }
213
214    fn llm_complete(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
215        let llm = self.require_llm(stage_id)?;
216
217        let prompt = input["prompt"].as_str().unwrap_or("").to_string();
218        let model = input["model"]
219            .as_str()
220            .unwrap_or(&self.llm_config.model)
221            .to_string();
222        let max_tokens = input["max_tokens"]
223            .as_u64()
224            .map(|v| v as u32)
225            .unwrap_or(self.llm_config.max_tokens);
226        let temperature = input["temperature"]
227            .as_f64()
228            .map(|v| v as f32)
229            .unwrap_or(self.llm_config.temperature);
230        let system_opt = input["system"].as_str();
231
232        let mut messages = vec![];
233        if let Some(sys) = system_opt {
234            messages.push(Message::system(sys));
235        }
236        messages.push(Message::user(&prompt));
237
238        let cfg = LlmConfig {
239            model: model.clone(),
240            max_tokens,
241            temperature,
242        };
243
244        // LLM call deduplication: identical (model, prompt, system) calls within the same
245        // session return the cached response instead of making a redundant API call.
246        let dedup_key = {
247            use sha2::{Digest, Sha256};
248            let key_data = format!("{}:{}:{}", model, system_opt.unwrap_or(""), prompt);
249            hex::encode(Sha256::digest(key_data.as_bytes()))
250        };
251
252        {
253            let cache = self.llm_dedup_cache.lock().unwrap();
254            if let Some(cached) = cache.get(&dedup_key) {
255                let mut result = cached.clone();
256                result["from_llm_cache"] = json!(true);
257                return Ok(result);
258            }
259        }
260
261        let text = llm
262            .complete(&messages, &cfg)
263            .map_err(|e| ExecutionError::StageFailed {
264                stage_id: stage_id.clone(),
265                message: format!("LLM error: {e}"),
266            })?;
267
268        let tokens_used = text.split_whitespace().count() as u64;
269
270        let result = json!({
271            "text": text,
272            "tokens_used": tokens_used,
273            "model": model,
274            "from_llm_cache": false,
275        });
276
277        self.llm_dedup_cache
278            .lock()
279            .unwrap()
280            .insert(dedup_key, result.clone());
281
282        Ok(result)
283    }
284
    /// `llm_embed`: produce a vector embedding for `input["text"]`.
    ///
    /// Prefers the attached `EmbeddingProvider`. When none is configured, falls
    /// back to asking the LLM to emit a small JSON float array — a rough
    /// stand-in, not a real embedding model.
    ///
    /// Output: `{ embedding: [f32...], dimensions, model }`.
    fn llm_embed(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
        let text = input["text"].as_str().unwrap_or("").to_string();
        let model_override = input["model"].as_str().map(|s| s.to_string());

        // Prefer real embedding provider when available.
        if let Some(ep) = &self.embedding_provider {
            let emb = ep.embed(&text).map_err(|e| ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: format!("embedding provider error: {e}"),
            })?;
            let dims = emb.len() as u64;
            // NOTE(review): the provider's actual model name isn't surfaced here,
            // so without an override we report a generic placeholder.
            let model = model_override.unwrap_or_else(|| "embedding-model".into());
            return Ok(json!({
                "embedding": emb,
                "dimensions": dims,
                "model": model,
            }));
        }

        // Fallback: ask the LLM to generate a JSON array of floats.
        let llm = self.require_llm(stage_id)?;
        let model = model_override.unwrap_or_else(|| "text-embedding-004".to_string());

        let prompt = format!(
            "Generate a compact 8-dimensional embedding vector for this text as a JSON array of floats: \"{text}\". Respond ONLY with a JSON array like [0.1, -0.2, ...]."
        );
        let messages = vec![
            Message::system("You are an embedding model. Respond only with a JSON float array."),
            Message::user(&prompt),
        ];
        // Deterministic (temperature 0) and tightly bounded output.
        let cfg = LlmConfig {
            model: model.clone(),
            max_tokens: 128,
            temperature: 0.0,
        };

        let response = llm
            .complete(&messages, &cfg)
            .map_err(|e| ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: format!("LLM error: {e}"),
            })?;

        // The model may wrap the array in prose; extract the first [...] span.
        let embedding: Value =
            extract_json_array(&response).ok_or_else(|| ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: format!("could not parse embedding from LLM response: {response:?}"),
            })?;

        let dims = embedding.as_array().map(|a| a.len()).unwrap_or(0) as u64;

        Ok(json!({
            "embedding": embedding,
            "dimensions": dims,
            "model": model,
        }))
    }
342
    /// `llm_classify`: assign `input["text"]` to exactly one of
    /// `input["categories"]` via the LLM.
    ///
    /// Fails when the category list is empty or when the LLM returns a category
    /// not in the list (the response is validated, never trusted blindly).
    /// Output: `{ category, confidence, model }`.
    fn llm_classify(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
        let llm = self.require_llm(stage_id)?;

        let text = input["text"].as_str().unwrap_or("").to_string();
        let model = input["model"]
            .as_str()
            .unwrap_or(&self.llm_config.model)
            .to_string();
        // Non-string array entries are silently dropped.
        let categories: Vec<String> = input["categories"]
            .as_array()
            .map(|a| {
                a.iter()
                    .filter_map(|v| v.as_str())
                    .map(|s| s.to_string())
                    .collect()
            })
            .unwrap_or_default();

        if categories.is_empty() {
            return Err(ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: "categories list is empty".into(),
            });
        }

        let cats_str = categories.join(", ");
        let prompt = format!(
            "Classify the following text into EXACTLY ONE of these categories: {cats_str}\n\nText: \"{text}\"\n\nRespond with ONLY valid JSON: {{\"category\": \"<one of the categories>\", \"confidence\": <0.0-1.0>}}"
        );

        let messages = vec![
            Message::system(
                "You are a text classifier. Always respond with valid JSON only. No explanation.",
            ),
            Message::user(&prompt),
        ];
        // Deterministic, tiny response: one JSON object.
        let cfg = LlmConfig {
            model: model.clone(),
            max_tokens: 64,
            temperature: 0.0,
        };

        let response = llm
            .complete(&messages, &cfg)
            .map_err(|e| ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: format!("LLM error: {e}"),
            })?;

        // The model may wrap the JSON in prose; extract the first {...} span.
        let parsed: Value =
            extract_json_object(&response).ok_or_else(|| ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: format!("could not parse classification JSON from: {response:?}"),
            })?;

        // Reject hallucinated categories outright rather than passing them on.
        let category = parsed["category"].as_str().unwrap_or("").trim().to_string();
        if !categories.contains(&category) {
            return Err(ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: format!(
                    "LLM returned unknown category {category:?}; expected one of: {cats_str}"
                ),
            });
        }

        // Missing/invalid confidence defaults to full confidence.
        let confidence = parsed["confidence"].as_f64().unwrap_or(1.0);

        Ok(json!({
            "category": category,
            "confidence": confidence,
            "model": model,
        }))
    }
416
417    fn llm_extract(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
418        let llm = self.require_llm(stage_id)?;
419
420        let text = input["text"].as_str().unwrap_or("").to_string();
421        let model = input["model"]
422            .as_str()
423            .unwrap_or(&self.llm_config.model)
424            .to_string();
425        let schema = input.get("schema").cloned().unwrap_or(json!({}));
426        let schema_str = serde_json::to_string_pretty(&schema).unwrap_or_else(|_| "{}".to_string());
427
428        let prompt = format!(
429            "Extract structured data from the following text.\nSchema: {schema_str}\nText: \"{text}\"\n\nRespond with ONLY a valid JSON object matching the schema. No explanation."
430        );
431
432        let messages = vec![
433            Message::system(
434                "You are a structured data extractor. Always respond with valid JSON only.",
435            ),
436            Message::user(&prompt),
437        ];
438        let cfg = LlmConfig {
439            model: model.clone(),
440            max_tokens: 512,
441            temperature: 0.0,
442        };
443
444        let response = llm
445            .complete(&messages, &cfg)
446            .map_err(|e| ExecutionError::StageFailed {
447                stage_id: stage_id.clone(),
448                message: format!("LLM error: {e}"),
449            })?;
450
451        let extracted =
452            extract_json_object(&response).ok_or_else(|| ExecutionError::StageFailed {
453                stage_id: stage_id.clone(),
454                message: format!("could not parse extraction JSON from: {response:?}"),
455            })?;
456
457        Ok(json!({
458            "extracted": extracted,
459            "model": model,
460        }))
461    }
462
463    // ── Store-aware stages ────────────────────────────────────────────────────
464
465    fn stage_describe(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
466        let id = input["id"].as_str().unwrap_or("").to_string();
467
468        let cached = self
469            .stage_cache
470            .iter()
471            .find(|s| s.id == id || s.id.starts_with(&id))
472            .ok_or_else(|| ExecutionError::StageFailed {
473                stage_id: stage_id.clone(),
474                message: format!("stage {id:?} not found"),
475            })?;
476
477        Ok(json!({
478            "id": cached.id,
479            "description": cached.description,
480            "input": cached.input_display,
481            "output": cached.output_display,
482            "effects": cached.effects,
483            "lifecycle": cached.lifecycle,
484            "examples_count": cached.examples_count,
485        }))
486    }
487
    /// Search the stage store by semantic query.
    ///
    /// When an `EmbeddingProvider` has been attached via `with_embedding()`, uses
    /// cosine similarity over pre-computed stage embeddings for real semantic search.
    /// Falls back to case-insensitive substring match when no embedding provider is present.
    fn store_search(&self, _stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
        let query = input["query"].as_str().unwrap_or("");
        // Result cap; defaults to 10 when absent or non-integer.
        let limit = input["limit"].as_u64().unwrap_or(10) as usize;

        if let Some(ep) = &self.embedding_provider {
            // Semantic search via cosine similarity.
            // If embedding the query fails, we silently fall through to the
            // substring path below rather than failing the stage.
            if let Ok(query_emb) = ep.embed(query) {
                // Stages without a pre-computed embedding are excluded here.
                let mut scored: Vec<(f32, &CachedStage)> = self
                    .stage_cache
                    .iter()
                    .filter_map(|s| {
                        self.stage_embeddings
                            .get(&s.id)
                            .map(|emb| (cosine_similarity(&query_emb, emb), s))
                    })
                    .collect();

                // Descending by score; NaN comparisons collapse to Equal.
                scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

                let results: Vec<Value> = scored
                    .into_iter()
                    .take(limit)
                    .map(|(score, s)| {
                        json!({
                            "id": s.id,
                            "description": s.description,
                            "input": s.input_display,
                            "output": s.output_display,
                            "score": score,
                        })
                    })
                    .collect();

                return Ok(Value::Array(results));
            }
        }

        // Substring fallback: case-insensitive match on description or either
        // type display; every hit gets a flat score of 1.0 (no ranking).
        let query_lc = query.to_lowercase();
        let results: Vec<Value> = self
            .stage_cache
            .iter()
            .filter(|s| {
                s.description.to_lowercase().contains(&query_lc)
                    || s.input_display.to_lowercase().contains(&query_lc)
                    || s.output_display.to_lowercase().contains(&query_lc)
            })
            .take(limit)
            .map(|s| {
                json!({
                    "id": s.id,
                    "description": s.description,
                    "input": s.input_display,
                    "output": s.output_display,
                    "score": 1.0,
                })
            })
            .collect();

        Ok(Value::Array(results))
    }
554
    /// Verify a composition by resolving its stage IDs and type-checking sequential chains.
    ///
    /// Input: `{ stages: List<Text>, operators: List<Text> }`
    /// Output: `{ valid: Bool, errors: List<Text>, warnings: List<Text> }`
    fn composition_verify(
        &self,
        stage_id: &StageId,
        input: &Value,
    ) -> Result<Value, ExecutionError> {
        // Non-string array entries are silently dropped.
        let stage_ids: Vec<&str> = input["stages"]
            .as_array()
            .map(|a| a.iter().filter_map(|v| v.as_str()).collect())
            .unwrap_or_default();

        let operators: Vec<&str> = input["operators"]
            .as_array()
            .map(|a| a.iter().filter_map(|v| v.as_str()).collect())
            .unwrap_or_default();

        let mut errors: Vec<String> = vec![];
        let mut warnings: Vec<String> = vec![];

        // An empty composition is vacuously valid — warn but don't fail.
        if stage_ids.is_empty() {
            warnings.push("empty composition".into());
            return Ok(json!({ "valid": true, "errors": errors, "warnings": warnings }));
        }

        // Validate operator names (case-insensitively).
        let valid_ops = [
            "sequential",
            "parallel",
            "branch",
            "fanout",
            "merge",
            "retry",
        ];
        for op in &operators {
            let op_lc = op.to_lowercase();
            if !valid_ops.contains(&op_lc.as_str()) {
                errors.push(format!("unknown operator: {op}"));
            }
        }

        // Resolve stage IDs and build a lookup by id for type-checking
        let id_to_cache: HashMap<&str, &CachedStage> = self
            .stage_cache
            .iter()
            .map(|s| (s.id.as_str(), s))
            .collect();

        // Resolve each referenced stage: deprecated → warning,
        // tombstone → error, missing → error.
        let mut resolved_stages: Vec<&CachedStage> = vec![];
        for sid in &stage_ids {
            match id_to_cache.get(sid) {
                Some(s) => {
                    if s.lifecycle == "deprecated" {
                        warnings.push(format!("stage {} ({}) is deprecated", sid, s.description));
                    }
                    if s.lifecycle == "tombstone" {
                        errors.push(format!(
                            "stage {} is a tombstone and cannot be executed",
                            sid
                        ));
                    }
                    resolved_stages.push(s);
                }
                None => {
                    errors.push(format!("stage {sid} not found in store"));
                }
            }
        }

        // For sequential compositions: type-check consecutive pairs.
        // We parse the stored display strings back to NType for comparison.
        if operators.iter().any(|op| op.to_lowercase() == "sequential") && resolved_stages.len() > 1
        {
            for i in 0..resolved_stages.len() - 1 {
                let out_str = &resolved_stages[i].output_display;
                let in_str = &resolved_stages[i + 1].input_display;

                // First attempt: wrap the display text in quotes and deserialize
                // it as an NType JSON string. NOTE(review): this only works if
                // NType's serde format accepts bare string names — confirm;
                // otherwise only the parse_ntype_display fallback ever fires.
                let out_type: Option<NType> = serde_json::from_str(&format!("\"{}\"", out_str))
                    .ok()
                    .or_else(|| parse_ntype_display(out_str));
                let in_type: Option<NType> = serde_json::from_str(&format!("\"{}\"", in_str))
                    .ok()
                    .or_else(|| parse_ntype_display(in_str));

                if let (Some(out), Some(inp)) = (out_type, in_type) {
                    use noether_core::types::{is_subtype_of, TypeCompatibility};
                    if let TypeCompatibility::Incompatible(reason) = is_subtype_of(&out, &inp) {
                        errors.push(format!(
                            "type mismatch between stages {} and {}: {} is not compatible with {} ({})",
                            stage_ids[i], stage_ids[i + 1], out_str, in_str, reason
                        ));
                    }
                }
                // If we can't parse types, we skip the check rather than emitting a false error.
            }
        }

        // stage_id is not needed on this path; discard to silence the unused
        // parameter without changing the signature.
        let _ = stage_id;

        // Valid iff no errors were accumulated; warnings don't invalidate.
        let valid = errors.is_empty();
        Ok(json!({
            "valid": valid,
            "errors": errors,
            "warnings": warnings,
        }))
    }
664}
665
// StageExecutor entry point: every call routes through the description-based
// dispatcher above.
impl StageExecutor for RuntimeExecutor {
    fn execute(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
        self.dispatch(stage_id, input)
    }
}
671
672// ── Pure helpers (no LLM / store state) ──────────────────────────────────────
673
674/// `type_check`: `{sub: NType JSON, sup: NType JSON}` → `{compatible: bool, reason: Text|Null}`
675fn type_check(stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
676    use noether_core::types::{is_subtype_of, TypeCompatibility};
677
678    let sub = parse_ntype_input(&input["sub"]).ok_or_else(|| ExecutionError::StageFailed {
679        stage_id: stage_id.clone(),
680        message: format!("could not parse sub type from: {}", input["sub"]),
681    })?;
682
683    let sup = parse_ntype_input(&input["sup"]).ok_or_else(|| ExecutionError::StageFailed {
684        stage_id: stage_id.clone(),
685        message: format!("could not parse sup type from: {}", input["sup"]),
686    })?;
687
688    match is_subtype_of(&sub, &sup) {
689        TypeCompatibility::Compatible => Ok(json!({"compatible": true, "reason": null})),
690        TypeCompatibility::Incompatible(reason) => {
691            Ok(json!({"compatible": false, "reason": format!("{reason}")}))
692        }
693    }
694}
695
696// ── Parsing helpers ───────────────────────────────────────────────────────────
697
698/// Parse an NType from either:
699/// - A JSON string like `"Text"`, `"Number"`, `"Bool"`, `"Any"`, `"Null"`, `"Bytes"`
700/// - A JSON object (the NType serde representation) like `{"kind": "Text"}`
701fn parse_ntype_input(v: &Value) -> Option<NType> {
702    if let Some(s) = v.as_str() {
703        match s {
704            "Text" => return Some(NType::Text),
705            "Number" => return Some(NType::Number),
706            "Bool" => return Some(NType::Bool),
707            "Any" => return Some(NType::Any),
708            "Null" => return Some(NType::Null),
709            "Bytes" => return Some(NType::Bytes),
710            _ => {}
711        }
712    }
713    serde_json::from_value(v.clone()).ok()
714}
715
716/// Parse a display string (e.g. "Text", "Number", "Any") into an NType.
717/// Used for type-checking in composition_verify.
718fn parse_ntype_display(s: &str) -> Option<NType> {
719    match s.trim() {
720        "Text" => Some(NType::Text),
721        "Number" => Some(NType::Number),
722        "Bool" => Some(NType::Bool),
723        "Any" => Some(NType::Any),
724        "Null" => Some(NType::Null),
725        "Bytes" => Some(NType::Bytes),
726        "VNode" => Some(NType::VNode),
727        _ => None,
728    }
729}
730
731/// Extract the first JSON array `[...]` found in a string.
732fn extract_json_array(s: &str) -> Option<Value> {
733    let start = s.find('[')?;
734    let end = s.rfind(']').map(|i| i + 1)?;
735    serde_json::from_str(&s[start..end]).ok()
736}
737
738/// Extract the first JSON object `{...}` found in a string.
739fn extract_json_object(s: &str) -> Option<Value> {
740    let start = s.find('{')?;
741    let end = s.rfind('}').map(|i| i + 1)?;
742    serde_json::from_str(&s[start..end]).ok()
743}
744
745// ── Tests ─────────────────────────────────────────────────────────────────────
746
#[cfg(test)]
mod tests {
    use super::*;
    use noether_core::stdlib::load_stdlib;
    use noether_store::MemoryStore;

    /// Build a `MemoryStore` pre-populated with every stdlib stage.
    fn stdlib_store() -> MemoryStore {
        let mut store = MemoryStore::new();
        for s in load_stdlib() {
            let _ = store.put(s);
        }
        store
    }

    /// Executor over the stdlib store, with no LLM or embedding provider.
    fn stdlib_runtime() -> RuntimeExecutor {
        RuntimeExecutor::from_store(&stdlib_store())
    }

    /// Look up a stage ID by a substring of its description.
    ///
    /// Panics with a descriptive message when no stage matches, so a
    /// missing stdlib stage fails the test at the lookup rather than at a
    /// later opaque `unwrap`.
    fn find_stage(rt: &RuntimeExecutor, needle: &str) -> StageId {
        rt.descriptions
            .iter()
            .find(|(_, v)| v.contains(needle))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap_or_else(|| panic!("no stage whose description contains {needle:?}"))
    }

    #[test]
    fn type_check_compatible() {
        let rt = stdlib_runtime();
        let id = find_stage(&rt, "structural subtype");
        let result = rt
            .execute(&id, &json!({"sub": "Text", "sup": "Text"}))
            .unwrap();
        assert_eq!(result["compatible"], json!(true));
        assert_eq!(result["reason"], json!(null));
    }

    #[test]
    fn type_check_incompatible() {
        let rt = stdlib_runtime();
        let id = find_stage(&rt, "structural subtype");
        let result = rt
            .execute(&id, &json!({"sub": "Text", "sup": "Number"}))
            .unwrap();
        assert_eq!(result["compatible"], json!(false));
        assert!(result["reason"].is_string());
    }

    #[test]
    fn stage_describe_includes_effects() {
        let rt = stdlib_runtime();
        let describe_id = find_stage(&rt, "Get detailed information");
        // The describe stage takes the target's raw `String` ID as input.
        let to_text_id = find_stage(&rt, "Convert any value to its text").0;

        let result = rt
            .execute(&describe_id, &json!({"id": to_text_id}))
            .unwrap();
        assert_eq!(result["id"], json!(to_text_id));
        assert!(result["description"].as_str().unwrap().contains("text"));
        // `effects` is now a list; pin the new shape.
        assert!(result["effects"].is_array(), "effects should be an array");
        assert!(result["examples_count"].as_u64().unwrap() > 0);
    }

    #[test]
    fn store_search_finds_stages() {
        let rt = stdlib_runtime();
        let search_id = find_stage(&rt, "Search the stage store");
        let result = rt
            .execute(&search_id, &json!({"query": "sort", "limit": 5}))
            .unwrap();
        let hits = result.as_array().unwrap();
        assert!(!hits.is_empty());
        assert!(hits
            .iter()
            .any(|h| h["description"].as_str().unwrap_or("").contains("Sort")));
    }

    #[test]
    fn store_search_with_embedding_provider() {
        use crate::index::embedding::MockEmbeddingProvider;
        let rt = RuntimeExecutor::from_store(&stdlib_store())
            .with_embedding(Box::new(MockEmbeddingProvider::new(32)));

        let search_id = find_stage(&rt, "Search the stage store");
        let result = rt
            .execute(&search_id, &json!({"query": "sort list", "limit": 10}))
            .unwrap();
        let hits = result.as_array().unwrap();
        assert!(!hits.is_empty());
        // Similarity-based scores must stay normalized to [0, 1].
        for h in hits {
            let score = h["score"].as_f64().unwrap();
            assert!((0.0..=1.0).contains(&score), "score {score} out of range");
        }
    }

    #[test]
    fn composition_verify_valid_stages() {
        let rt = stdlib_runtime();
        let verify_id = find_stage(&rt, "Verify that a composition graph");

        // Two real stage IDs from the store cache.
        let ids: Vec<String> = rt
            .stage_cache
            .iter()
            .take(2)
            .map(|s| s.id.clone())
            .collect();

        let result = rt
            .execute(
                &verify_id,
                &json!({
                    "stages": ids,
                    "operators": ["sequential"]
                }),
            )
            .unwrap();
        // Should succeed even if types don't match (warnings, not errors for this)
        assert!(result["errors"].is_array());
        assert!(result["warnings"].is_array());
    }

    #[test]
    fn composition_verify_unknown_stage_is_error() {
        let rt = stdlib_runtime();
        let verify_id = find_stage(&rt, "Verify that a composition graph");

        let result = rt
            .execute(
                &verify_id,
                &json!({
                    "stages": ["nonexistent-stage-id"],
                    "operators": []
                }),
            )
            .unwrap();
        assert_eq!(result["valid"], json!(false));
        assert!(result["errors"]
            .as_array()
            .unwrap()
            .iter()
            .any(|e| e.as_str().unwrap_or("").contains("not found")));
    }

    #[test]
    fn llm_complete_fails_gracefully_without_llm() {
        let rt = stdlib_runtime();
        let llm_id = find_stage(&rt, "Generate text completion");
        let result = rt.execute(
            &llm_id,
            &json!({"prompt": "Hello", "model": null, "max_tokens": null, "temperature": null, "system": null}),
        );
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("LLM provider not configured"),
            "expected config error, got: {msg}"
        );
    }

    #[test]
    fn llm_embed_uses_embedding_provider_when_available() {
        use crate::index::embedding::MockEmbeddingProvider;
        let rt = RuntimeExecutor::from_store(&stdlib_store())
            .with_embedding(Box::new(MockEmbeddingProvider::new(16)));

        let embed_id = find_stage(&rt, "Generate a vector embedding");
        let result = rt
            .execute(&embed_id, &json!({"text": "hello world", "model": null}))
            .unwrap();
        assert_eq!(result["dimensions"], json!(16u64));
        assert_eq!(result["embedding"].as_array().unwrap().len(), 16);
    }

    /// Verify the `Mutex<HashMap>` LLM dedup cache is safe under concurrent access.
    #[test]
    fn llm_dedup_cache_concurrent_access() {
        use crate::llm::MockLlmProvider;
        use std::sync::Arc;

        let mock_response = r#"{"category":"positive","confidence":0.99,"model":"mock"}"#;

        let rt = RuntimeExecutor::from_store(&stdlib_store()).with_llm(
            Box::new(MockLlmProvider::new(mock_response)),
            LlmConfig::default(),
        );
        let rt = Arc::new(rt);

        let classify_id = find_stage(&rt, "Classify text into one of");

        let input = json!({
            "text": "I love this product",
            "categories": ["positive", "negative", "neutral"],
            "model": null
        });

        // All 16 threads hit the same (stage, input) pair, exercising the
        // dedup cache's lock under contention. `thread::scope` lets them
        // borrow `input`/`classify_id` clones without `'static` bounds.
        let results: Vec<_> = std::thread::scope(|s| {
            let handles: Vec<_> = (0..16)
                .map(|_| {
                    let rt = Arc::clone(&rt);
                    let id = classify_id.clone();
                    let inp = input.clone();
                    s.spawn(move || rt.execute(&id, &inp))
                })
                .collect();
            handles.into_iter().map(|h| h.join().unwrap()).collect()
        });

        assert_eq!(results.len(), 16);
        let first = results[0].as_ref().expect("first result must be Ok");
        for (i, r) in results.iter().enumerate() {
            let val = r
                .as_ref()
                .unwrap_or_else(|e| panic!("thread {i} failed: {e}"));
            assert_eq!(
                val["category"], first["category"],
                "thread {i} returned different category"
            );
        }
        assert_eq!(first["category"].as_str().unwrap(), "positive");
    }
}