Skip to main content

kronroe_wasm/
wasm_bindings.rs

1//! Kronroe WASM — browser-compatible temporal graph database.
2//!
3//! This crate wraps [`kronroe_agent_memory::AgentMemory`] for WebAssembly
4//! environments. It uses an in-memory storage backend (no file I/O), making it
5//! suitable for browser-based demos and agent workflows.
6//!
7//! # Usage (JavaScript)
8//!
9//! ```js
10//! import init, { WasmGraph } from 'kronroe-wasm';
11//!
12//! await init();
//! const graph = new WasmGraph();
14//! const factId = graph.assert_fact("alice", "works_at", "Acme");
15//! const facts = graph.current_facts("alice", "works_at");
16//! console.log(JSON.parse(facts));
17//! ```
18
19use chrono::{DateTime, Utc};
20use kronroe::{FactId, Value};
21use kronroe_agent_memory::{AgentMemory, AssertParams, RecallOptions, RecallScore};
22use serde_json::json;
23use serde_json::Value as JsonValue;
24use wasm_bindgen::prelude::*;
25
26#[cfg(feature = "hybrid")]
27use kronroe::{TemporalIntent, TemporalOperator};
28
29// ---------------------------------------------------------------------------
30// Error handling
31// ---------------------------------------------------------------------------
32
33/// Convert KronroeError to a JsValue for wasm-bindgen.
34fn to_js_err(e: kronroe::KronroeError) -> JsValue {
35    JsValue::from_str(&e.to_string())
36}
37
38fn parse_valid_from(iso: &str) -> Result<DateTime<Utc>, JsValue> {
39    iso.parse::<DateTime<Utc>>()
40        .map_err(|e: chrono::ParseError| JsValue::from_str(&e.to_string()))
41}
42
43fn parse_embedding(embedding: Option<Vec<f64>>) -> Result<Option<Vec<f32>>, JsValue> {
44    let Some(embedding) = embedding else {
45        return Ok(None);
46    };
47    if embedding.is_empty() {
48        return Err(JsValue::from_str("query_embedding must not be empty"));
49    }
50
51    let mut out = Vec::with_capacity(embedding.len());
52    for value in embedding {
53        if !value.is_finite() {
54            return Err(JsValue::from_str("query_embedding values must be finite"));
55        }
56        let narrowed = value as f32;
57        if !narrowed.is_finite() {
58            return Err(JsValue::from_str(
59                "query_embedding value overflows f32 range",
60            ));
61        }
62        out.push(narrowed);
63    }
64    Ok(Some(out))
65}
66
/// Map an optional JS string onto a [`TemporalIntent`].
///
/// Accepted labels (case-sensitive): `timeless`, `current_state`,
/// `historical_point`, `historical_interval`. `None` passes through as
/// `Ok(None)`; any other label is an error listing the accepted values.
#[cfg(feature = "hybrid")]
fn parse_temporal_intent(raw: Option<String>) -> Result<Option<TemporalIntent>, JsValue> {
    raw.map(|label| match label.as_str() {
        "timeless" => Ok(TemporalIntent::Timeless),
        "current_state" => Ok(TemporalIntent::CurrentState),
        "historical_point" => Ok(TemporalIntent::HistoricalPoint),
        "historical_interval" => Ok(TemporalIntent::HistoricalInterval),
        _ => Err(JsValue::from_str(
            "invalid temporal_intent: expected timeless|current_state|historical_point|historical_interval",
        )),
    })
    .transpose()
}
85
/// Map an optional JS string onto a [`TemporalOperator`].
///
/// Accepted labels (case-sensitive): `current`, `as_of`, `before`, `by`,
/// `during`, `after`, `unknown`. `None` passes through as `Ok(None)`; any
/// other label is an error listing the accepted values.
#[cfg(feature = "hybrid")]
fn parse_temporal_operator(raw: Option<String>) -> Result<Option<TemporalOperator>, JsValue> {
    raw.map(|label| match label.as_str() {
        "current" => Ok(TemporalOperator::Current),
        "as_of" => Ok(TemporalOperator::AsOf),
        "before" => Ok(TemporalOperator::Before),
        "by" => Ok(TemporalOperator::By),
        "during" => Ok(TemporalOperator::During),
        "after" => Ok(TemporalOperator::After),
        "unknown" => Ok(TemporalOperator::Unknown),
        _ => Err(JsValue::from_str(
            "invalid temporal_operator: expected current|as_of|before|by|during|after|unknown",
        )),
    })
    .transpose()
}
107
108fn recall_score_payload(score: &RecallScore) -> JsonValue {
109    match score {
110        RecallScore::Hybrid {
111            rrf_score,
112            text_contrib,
113            vector_contrib,
114            confidence,
115            effective_confidence,
116            ..
117        } => json!({
118            "kind": "hybrid",
119            "rrf_score": rrf_score,
120            "text_contrib": text_contrib,
121            "vector_contrib": vector_contrib,
122            "confidence": confidence,
123            "effective_confidence": effective_confidence,
124        }),
125        RecallScore::TextOnly {
126            rank,
127            bm25_score,
128            confidence,
129            effective_confidence,
130            ..
131        } => json!({
132            "kind": "text_only",
133            "rank": rank,
134            "bm25_score": bm25_score,
135            "confidence": confidence,
136            "effective_confidence": effective_confidence,
137        }),
138        _ => json!({
139            "kind": "unsupported",
140            "warning": "RecallScore variant not yet supported in wasm bindings",
141        }),
142    }
143}
144
/// Normalise an optional source marker coming from JavaScript.
///
/// An empty string is treated the same as "no source" (`None`). Any
/// non-empty string, including whitespace-only, is kept unchanged.
fn extract_source(source: Option<String>) -> Option<String> {
    source.filter(|marker| !marker.is_empty())
}
154
155// ---------------------------------------------------------------------------
156// WasmGraph — the public API
157// ---------------------------------------------------------------------------
158
159/// An in-memory AgentMemory store for browser environments.
160///
161/// All data lives in memory and is lost when the instance is dropped.
162/// This is designed for demos, playgrounds, and ephemeral workloads.
163#[wasm_bindgen]
164pub struct WasmGraph {
165    inner: AgentMemory,
166}
167
168#[wasm_bindgen]
169impl WasmGraph {
170    /// Create a new in-memory AgentMemory instance.
171    #[wasm_bindgen(constructor)]
172    pub fn open() -> Result<WasmGraph, JsValue> {
173        let inner = AgentMemory::open_in_memory().map_err(to_js_err)?;
174        Ok(WasmGraph { inner })
175    }
176
177    /// Assert a new fact and return its ID.
178    ///
179    /// The object is stored as a text value. For typed values (number,
180    /// boolean, entity reference), use typed methods.
181    #[wasm_bindgen]
182    pub fn assert_fact(
183        &self,
184        subject: &str,
185        predicate: &str,
186        object: &str,
187    ) -> Result<String, JsValue> {
188        let id = self
189            .inner
190            .assert(subject, predicate, object)
191            .map_err(to_js_err)?;
192        Ok(id.to_string())
193    }
194
195    /// Assert a fact with a specific valid_from timestamp (ISO 8601).
196    #[wasm_bindgen]
197    pub fn assert_fact_at(
198        &self,
199        subject: &str,
200        predicate: &str,
201        object: &str,
202        valid_from_iso: &str,
203    ) -> Result<String, JsValue> {
204        let valid_from = parse_valid_from(valid_from_iso)?;
205        let id = self
206            .inner
207            .assert_with_params(subject, predicate, object, AssertParams { valid_from })
208            .map_err(to_js_err)?;
209        Ok(id.to_string())
210    }
211
212    /// Assert a numeric fact.
213    #[wasm_bindgen]
214    pub fn assert_number_fact(
215        &self,
216        subject: &str,
217        predicate: &str,
218        value: f64,
219    ) -> Result<String, JsValue> {
220        let id = self
221            .inner
222            .assert(subject, predicate, Value::Number(value))
223            .map_err(to_js_err)?;
224        Ok(id.to_string())
225    }
226
227    /// Assert a boolean fact.
228    #[wasm_bindgen]
229    pub fn assert_boolean_fact(
230        &self,
231        subject: &str,
232        predicate: &str,
233        value: bool,
234    ) -> Result<String, JsValue> {
235        let id = self
236            .inner
237            .assert(subject, predicate, Value::Boolean(value))
238            .map_err(to_js_err)?;
239        Ok(id.to_string())
240    }
241
242    /// Assert an entity reference fact (graph edge).
243    #[wasm_bindgen]
244    pub fn assert_entity_fact(
245        &self,
246        subject: &str,
247        predicate: &str,
248        entity: &str,
249    ) -> Result<String, JsValue> {
250        let id = self
251            .inner
252            .assert(subject, predicate, Value::Entity(entity.to_string()))
253            .map_err(to_js_err)?;
254        Ok(id.to_string())
255    }
256
257    /// Assert a numeric fact with a specific valid_from timestamp (ISO 8601).
258    #[wasm_bindgen]
259    pub fn assert_number_fact_at(
260        &self,
261        subject: &str,
262        predicate: &str,
263        value: f64,
264        valid_from_iso: &str,
265    ) -> Result<String, JsValue> {
266        let valid_from = parse_valid_from(valid_from_iso)?;
267        let id = self
268            .inner
269            .assert_with_params(
270                subject,
271                predicate,
272                Value::Number(value),
273                AssertParams { valid_from },
274            )
275            .map_err(to_js_err)?;
276        Ok(id.to_string())
277    }
278
279    /// Assert a boolean fact with a specific valid_from timestamp (ISO 8601).
280    #[wasm_bindgen]
281    pub fn assert_boolean_fact_at(
282        &self,
283        subject: &str,
284        predicate: &str,
285        value: bool,
286        valid_from_iso: &str,
287    ) -> Result<String, JsValue> {
288        let valid_from = parse_valid_from(valid_from_iso)?;
289        let id = self
290            .inner
291            .assert_with_params(
292                subject,
293                predicate,
294                Value::Boolean(value),
295                AssertParams { valid_from },
296            )
297            .map_err(to_js_err)?;
298        Ok(id.to_string())
299    }
300
301    /// Assert an entity reference fact with a specific valid_from timestamp (ISO 8601).
302    #[wasm_bindgen]
303    pub fn assert_entity_fact_at(
304        &self,
305        subject: &str,
306        predicate: &str,
307        entity: &str,
308        valid_from_iso: &str,
309    ) -> Result<String, JsValue> {
310        let valid_from = parse_valid_from(valid_from_iso)?;
311        let id = self
312            .inner
313            .assert_with_params(
314                subject,
315                predicate,
316                Value::Entity(entity.to_string()),
317                AssertParams { valid_from },
318            )
319            .map_err(to_js_err)?;
320        Ok(id.to_string())
321    }
322
323    /// Assert a fact with confidence, optionally attaching a source marker.
324    #[wasm_bindgen]
325    pub fn assert_with_confidence(
326        &self,
327        subject: &str,
328        predicate: &str,
329        object: &str,
330        confidence: f64,
331        source: Option<String>,
332    ) -> Result<String, JsValue> {
333        if !confidence.is_finite() || !(0.0..=1.0).contains(&confidence) {
334            return Err(JsValue::from_str(
335                "confidence must be a finite number in [0.0, 1.0]",
336            ));
337        }
338        let confidence = confidence as f32;
339
340        let id = match extract_source(source) {
341            Some(source) => self
342                .inner
343                .assert_with_source(subject, predicate, object, confidence, &source),
344            None => self
345                .inner
346                .assert_with_confidence(subject, predicate, object, confidence),
347        }
348        .map_err(to_js_err)?;
349
350        Ok(id.to_string())
351    }
352
353    /// Correct a fact by ID.
354    #[wasm_bindgen]
355    pub fn correct_fact(&self, fact_id: &str, new_object: &str) -> Result<String, JsValue> {
356        let fact_id = FactId(fact_id.to_string());
357        let new_id = self
358            .inner
359            .correct_fact(&fact_id, new_object.to_string())
360            .map_err(to_js_err)?;
361        Ok(new_id.to_string())
362    }
363
364    /// Get all currently valid facts for (subject, predicate) as JSON.
365    #[wasm_bindgen]
366    pub fn current_facts(&self, subject: &str, predicate: &str) -> Result<String, JsValue> {
367        let facts = self
368            .inner
369            .current_facts(subject, predicate)
370            .map_err(to_js_err)?;
371        serde_json::to_string(&facts).map_err(|e| JsValue::from_str(&e.to_string()))
372    }
373
374    /// Get facts valid at a specific point in time (ISO 8601) as JSON.
375    #[wasm_bindgen]
376    pub fn facts_at(
377        &self,
378        subject: &str,
379        predicate: &str,
380        at_iso: &str,
381    ) -> Result<String, JsValue> {
382        let at = parse_valid_from(at_iso)?;
383        let facts = self
384            .inner
385            .facts_about_at(subject, predicate, at)
386            .map_err(to_js_err)?;
387        serde_json::to_string(&facts).map_err(|e| JsValue::from_str(&e.to_string()))
388    }
389
390    /// Get every fact ever recorded about an entity as JSON.
391    #[wasm_bindgen]
392    pub fn all_facts_about(&self, subject: &str) -> Result<String, JsValue> {
393        let facts = self.inner.facts_about(subject).map_err(to_js_err)?;
394        serde_json::to_string(&facts).map_err(|e| JsValue::from_str(&e.to_string()))
395    }
396
397    /// Alias for `all_facts_about`.
398    #[wasm_bindgen]
399    pub fn facts_about(&self, subject: &str) -> Result<String, JsValue> {
400        self.all_facts_about(subject)
401    }
402
403    /// Recall current facts for a query as JSON.
404    #[wasm_bindgen]
405    pub fn recall(
406        &self,
407        query: &str,
408        query_embedding: Option<Vec<f64>>,
409        limit: usize,
410    ) -> Result<String, JsValue> {
411        #[cfg(not(feature = "hybrid"))]
412        if query_embedding.is_some() {
413            return Err(JsValue::from_str(
414                "query_embedding is unavailable without the `hybrid` feature",
415            ));
416        }
417
418        let embedding = parse_embedding(query_embedding)?;
419        let facts = self
420            .inner
421            .recall(query, embedding.as_deref(), limit)
422            .map_err(to_js_err)?;
423        serde_json::to_string(&facts).map_err(|e| JsValue::from_str(&e.to_string()))
424    }
425
426    /// Recall facts with score metadata as JSON.
427    #[allow(clippy::too_many_arguments)]
428    #[wasm_bindgen]
429    pub fn recall_scored(
430        &self,
431        query: &str,
432        limit: usize,
433        query_embedding: Option<Vec<f64>>,
434        min_confidence: Option<f64>,
435        confidence_filter_mode: Option<String>,
436        max_scored_rows: Option<usize>,
437        use_hybrid: bool,
438        temporal_intent: Option<String>,
439        temporal_operator: Option<String>,
440    ) -> Result<String, JsValue> {
441        let embedding = parse_embedding(query_embedding)?;
442
443        #[cfg(not(feature = "hybrid"))]
444        if embedding.is_some()
445            || use_hybrid
446            || temporal_intent.is_some()
447            || temporal_operator.is_some()
448        {
449            return Err(JsValue::from_str(
450                "hybrid controls are unavailable without the `hybrid` feature",
451            ));
452        }
453
454        #[cfg(feature = "hybrid")]
455        if embedding.is_none()
456            && (use_hybrid || temporal_intent.is_some() || temporal_operator.is_some())
457        {
458            return Err(JsValue::from_str(
459                "query_embedding is required for hybrid/temporal controls",
460            ));
461        }
462
463        let mut opts = RecallOptions::new(query).with_limit(limit);
464        if let Some(embedding) = embedding.as_deref() {
465            opts = opts.with_embedding(embedding);
466            #[cfg(feature = "hybrid")]
467            if use_hybrid {
468                opts = opts.with_hybrid(true);
469            }
470        }
471
472        if confidence_filter_mode.is_some() && min_confidence.is_none() {
473            return Err(JsValue::from_str(
474                "confidence_filter_mode requires min_confidence",
475            ));
476        }
477
478        if let Some(min) = min_confidence {
479            if !min.is_finite() {
480                return Err(JsValue::from_str(
481                    "min_confidence/confidence must be finite",
482                ));
483            }
484            let mode = confidence_filter_mode
485                .as_deref()
486                .unwrap_or("base")
487                .to_ascii_lowercase();
488            if mode == "base" {
489                opts = opts.with_min_confidence(min as f32);
490            } else if mode == "effective" {
491                #[cfg(feature = "uncertainty")]
492                {
493                    opts = opts.with_min_effective_confidence(min as f32);
494                }
495                #[cfg(not(feature = "uncertainty"))]
496                {
497                    return Err(JsValue::from_str(
498                        "effective confidence filter requires the `uncertainty` feature",
499                    ));
500                }
501            } else {
502                return Err(JsValue::from_str(
503                    "confidence_filter_mode must be 'base' or 'effective'",
504                ));
505            }
506        }
507
508        if let Some(rows) = max_scored_rows {
509            opts = opts.with_max_scored_rows(rows);
510        }
511
512        #[cfg(feature = "hybrid")]
513        if let Some(intent) = parse_temporal_intent(temporal_intent)? {
514            opts = opts.with_temporal_intent(intent);
515        }
516        #[cfg(feature = "hybrid")]
517        if let Some(operator) = parse_temporal_operator(temporal_operator)? {
518            opts = opts.with_temporal_operator(operator);
519        }
520
521        let scored = self
522            .inner
523            .recall_scored_with_options(&opts)
524            .map_err(to_js_err)?;
525        let mut rows = Vec::with_capacity(scored.len());
526        for (fact, score) in scored {
527            rows.push(json!({
528                "fact": fact,
529                "score": recall_score_payload(&score),
530            }));
531        }
532        serde_json::to_string(&rows).map_err(|e| JsValue::from_str(&e.to_string()))
533    }
534
535    /// Build a memory-anchored prompt context from recalled facts.
536    #[wasm_bindgen]
537    pub fn assemble_context(
538        &self,
539        query: &str,
540        max_tokens: usize,
541        query_embedding: Option<Vec<f64>>,
542    ) -> Result<String, JsValue> {
543        if max_tokens == 0 {
544            return Err(JsValue::from_str("max_tokens must be >= 1"));
545        }
546
547        #[cfg(not(feature = "hybrid"))]
548        if query_embedding.is_some() {
549            return Err(JsValue::from_str(
550                "query_embedding is unavailable without the `hybrid` feature",
551            ));
552        }
553
554        let embedding = parse_embedding(query_embedding)?;
555        self.inner
556            .assemble_context(query, embedding.as_deref(), max_tokens)
557            .map_err(to_js_err)
558    }
559
560    /// Store an unstructured memory episode.
561    ///
562    /// Optional `idempotency_key` enables deduplicated retries.
563    #[wasm_bindgen]
564    pub fn remember(
565        &self,
566        text: &str,
567        episode_id: &str,
568        query_embedding: Option<Vec<f64>>,
569        idempotency_key: Option<String>,
570    ) -> Result<String, JsValue> {
571        if idempotency_key.is_some() && query_embedding.is_some() {
572            return Err(JsValue::from_str(
573                "idempotency_key is not supported with query_embedding in remember",
574            ));
575        }
576
577        #[cfg(not(feature = "hybrid"))]
578        if query_embedding.is_some() {
579            return Err(JsValue::from_str(
580                "query_embedding is unavailable without the `hybrid` feature",
581            ));
582        }
583
584        if let Some(key) = idempotency_key {
585            let id = self
586                .inner
587                .remember_idempotent(&key, text, episode_id)
588                .map_err(to_js_err)?;
589            return Ok(id.to_string());
590        }
591
592        let embedding = parse_embedding(query_embedding)?;
593        let id = if let Some(_embedding) = embedding {
594            #[cfg(feature = "hybrid")]
595            {
596                self.inner.remember(text, episode_id, Some(_embedding))
597            }
598            #[cfg(not(feature = "hybrid"))]
599            {
600                return Err(JsValue::from_str(
601                    "query_embedding is unavailable without the `hybrid` feature",
602                ));
603            }
604        } else {
605            self.inner.remember(text, episode_id, None)
606        }
607        .map_err(to_js_err)?;
608        Ok(id.to_string())
609    }
610
611    /// Invalidate a fact by its ID at the current time.
612    #[wasm_bindgen]
613    pub fn invalidate_fact(&self, fact_id: &str) -> Result<(), JsValue> {
614        let id = FactId(fact_id.to_string());
615        self.inner.invalidate_fact(&id).map_err(to_js_err)
616    }
617}
618
619// ---------------------------------------------------------------------------
620// Tests
621// ---------------------------------------------------------------------------
622
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: these run natively, where constructing a `JsValue` (every error
    // path) is not available, so only success paths are exercised here.

    #[test]
    fn wasm_graph_basic_operations() {
        let graph = WasmGraph::open().unwrap();

        // Assert and retrieve.
        let id = graph.assert_fact("alice", "works_at", "Acme").unwrap();
        assert!(!id.is_empty());

        let json = graph.current_facts("alice", "works_at").unwrap();
        assert!(json.contains("Acme"));
        assert!(json.contains("alice"));

        // Every fact about the subject, across predicates.
        graph.assert_fact("alice", "has_role", "Engineer").unwrap();
        let all = graph.all_facts_about("alice").unwrap();
        assert!(all.contains("works_at"));
        assert!(all.contains("has_role"));
    }

    #[test]
    fn wasm_graph_typed_values() {
        let graph = WasmGraph::open().unwrap();

        graph.assert_number_fact("alice", "score", 0.95).unwrap();
        graph.assert_boolean_fact("alice", "active", true).unwrap();
        graph
            .assert_entity_fact("alice", "employer", "acme_corp")
            .unwrap();

        let all = graph.all_facts_about("alice").unwrap();
        assert!(all.contains("0.95"));
        assert!(all.contains("true"));
        assert!(all.contains("acme_corp"));
    }

    #[test]
    fn wasm_graph_temporal_query() {
        let graph = WasmGraph::open().unwrap();

        graph
            .assert_fact_at("alice", "works_at", "Acme", "2024-01-01T00:00:00Z")
            .unwrap();

        // Valid in March 2024.
        let facts = graph
            .facts_at("alice", "works_at", "2024-03-01T00:00:00Z")
            .unwrap();
        assert!(facts.contains("Acme"));

        // Nothing is valid before January 2024: expect an empty JSON array,
        // mirroring the exact check in `wasm_graph_invalidation`.
        let empty = graph
            .facts_at("alice", "works_at", "2023-06-01T00:00:00Z")
            .unwrap();
        assert_eq!(empty, "[]");
    }

    #[test]
    fn wasm_graph_invalidation() {
        let graph = WasmGraph::open().unwrap();

        let id = graph.assert_fact("alice", "works_at", "Acme").unwrap();
        graph.invalidate_fact(&id).unwrap();

        let current = graph.current_facts("alice", "works_at").unwrap();
        // Should be an empty array because the fact was invalidated.
        assert_eq!(current, "[]");
    }

    #[test]
    fn wasm_graph_assert_with_confidence_accepts_optional_source() {
        let graph = WasmGraph::open().unwrap();

        graph
            .assert_with_confidence("alice", "works_at", "Acme", 0.8, None)
            .unwrap();
        graph
            .assert_with_confidence(
                "alice",
                "has_role",
                "Engineer",
                0.6,
                Some("crm".to_string()),
            )
            .unwrap();
        // An empty source is treated as "no source" rather than an error.
        graph
            .assert_with_confidence("alice", "team", "Platform", 1.0, Some(String::new()))
            .unwrap();

        let all = graph.all_facts_about("alice").unwrap();
        assert!(all.contains("Acme"));
        assert!(all.contains("Engineer"));
        assert!(all.contains("Platform"));
    }

    #[test]
    fn extract_source_treats_empty_as_absent() {
        assert_eq!(extract_source(None), None);
        assert_eq!(extract_source(Some(String::new())), None);
        assert_eq!(
            extract_source(Some("crm".to_string())),
            Some("crm".to_string())
        );
    }

    #[test]
    fn parse_embedding_narrows_finite_values() {
        assert_eq!(parse_embedding(None).unwrap(), None);
        assert_eq!(
            parse_embedding(Some(vec![0.5, -1.0, 0.0])).unwrap(),
            Some(vec![0.5_f32, -1.0, 0.0])
        );
    }

    #[cfg(feature = "hybrid")]
    #[test]
    fn wasm_graph_recall_scored_respects_min_confidence() {
        let graph = WasmGraph::open().unwrap();

        graph
            .remember(
                "Alice joined Acme engineering.",
                "ep-high",
                Some(vec![1.0, 0.0, 0.0]),
                None,
            )
            .unwrap();
        graph
            .remember(
                "Alice likes hiking on weekends.",
                "ep-low",
                Some(vec![0.0, 1.0, 0.0]),
                None,
            )
            .unwrap();

        let rows_json = graph
            .recall_scored(
                "Acme",
                10,
                Some(vec![1.0, 0.0, 0.0]),
                Some(0.5),
                Some("base".to_string()),
                None,
                true,
                None,
                None,
            )
            .unwrap();
        let rows: Vec<serde_json::Value> = serde_json::from_str(&rows_json).unwrap();
        assert!(!rows.is_empty());

        let score = &rows[0]["score"];
        assert_eq!(score["kind"], "hybrid");
        assert!(score["confidence"].as_f64().unwrap() >= 0.5);
    }

    #[test]
    fn wasm_graph_remember_persists_episode_fact() {
        let graph = WasmGraph::open().unwrap();

        graph
            .remember("Alice joined Acme as an engineer.", "ep-1", None, None)
            .unwrap();

        let all = graph.all_facts_about("ep-1").unwrap();
        assert!(all.contains("memory"));
        assert!(all.contains("Alice joined Acme as an engineer."));
    }
}