Skip to main content

redis_vl/extensions/
router.rs

1//! Semantic router extension types.
2//!
3//! [`SemanticRouter`](crate::SemanticRouter) classifies input text against
4//! predefined [`Route`](crate::Route)s using vector similarity. Each route is
5//! defined by a set of reference utterances; incoming text is embedded and
6//! matched against the closest references.
7//!
8//! The router supports per-route distance thresholds, configurable aggregation
9//! methods ([`DistanceAggregationMethod`](crate::DistanceAggregationMethod)),
10//! serialization to/from YAML and JSON, and `from_existing` reconnection to a
11//! previously created router index.
12
13use std::{collections::HashMap, path::Path, sync::Arc};
14
15use serde::{Deserialize, Serialize};
16use serde_json::{Map, Value, json};
17
18use crate::{
19    error::{Error, Result},
20    index::{QueryOutput, RedisConnectionInfo, SearchIndex},
21    query::{Vector, VectorRangeQuery},
22    schema::VectorDataType,
23    vectorizers::Vectorizer,
24};
25
/// Record field holding a reference's id; also passed to `index.load` as the id field.
const ROUTER_REFERENCE_ID_FIELD: &str = "reference_id";
/// Record field holding the name of the route a reference belongs to.
const ROUTER_ROUTE_NAME_FIELD: &str = "route_name";
/// Record field holding the reference utterance text.
const ROUTER_REFERENCE_FIELD: &str = "reference";
/// Record field holding the reference embedding; range queries search this field.
const ROUTER_VECTOR_FIELD: &str = "vector";
/// Distance threshold applied to routes that do not set their own.
const DEFAULT_ROUTE_DISTANCE_THRESHOLD: f32 = 0.5;
31
/// Aggregation method used to combine reference distances into a route score.
///
/// Variant names are serialized in lowercase (`avg`, `min`, `sum`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DistanceAggregationMethod {
    /// Average the distances of all matched route references.
    Avg,
    /// Use the minimum matched reference distance.
    Min,
    /// Sum all matched reference distances.
    Sum,
}
43
/// Route definition used by the semantic router.
///
/// `metadata` and `distance_threshold` may be omitted when deserializing; they
/// then fall back to an empty map and `None` respectively.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Route {
    /// Route name.
    pub name: String,
    /// Reference utterances used to anchor the route.
    pub references: Vec<String>,
    /// Optional route metadata.
    #[serde(default)]
    pub metadata: Map<String, Value>,
    /// Optional per-route threshold; the router default applies when `None`.
    #[serde(default)]
    pub distance_threshold: Option<f32>,
}
58
59impl Route {
60    /// Creates a route with the provided name and references.
61    pub fn new(name: impl Into<String>, references: Vec<String>) -> Self {
62        Self {
63            name: name.into(),
64            references,
65            metadata: Map::new(),
66            distance_threshold: None,
67        }
68    }
69
70    fn effective_threshold(&self) -> f32 {
71        self.distance_threshold
72            .unwrap_or(DEFAULT_ROUTE_DISTANCE_THRESHOLD)
73    }
74}
75
/// Route match result.
///
/// Both fields are `None` when no route matched.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RouteMatch {
    /// Matched route name.
    pub name: Option<String>,
    /// Calculated route distance.
    pub distance: Option<f32>,
}
84
85impl RouteMatch {
86    fn no_match() -> Self {
87        Self {
88            name: None,
89            distance: None,
90        }
91    }
92}
93
/// Runtime router configuration.
///
/// Both values act as defaults that individual `route_many` calls may override.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Maximum number of routes returned by `route_many`.
    pub max_k: usize,
    /// Aggregation method used to score a route from its matched references.
    pub aggregation_method: DistanceAggregationMethod,
}
102
103impl Default for RoutingConfig {
104    fn default() -> Self {
105        Self {
106            max_k: 1,
107            aggregation_method: DistanceAggregationMethod::Avg,
108        }
109    }
110}
111
/// Semantic router backed by a Redis Search vector index.
///
/// Cloning duplicates the route list and shares the vectorizer through its
/// `Arc`.
#[derive(Clone)]
pub struct SemanticRouter {
    /// Router name.
    pub name: String,
    /// Redis connection settings.
    pub connection: RedisConnectionInfo,
    /// Registered routes.
    pub routes: Vec<Route>,
    /// Runtime routing configuration.
    pub routing_config: RoutingConfig,
    /// Vector element data type used for the index schema.
    pub dtype: VectorDataType,
    /// Underlying search index used for route references.
    pub index: SearchIndex,
    /// Embedding backend used to vectorize references.
    vectorizer: Arc<dyn Vectorizer>,
    /// Embedding length measured by probing the vectorizer at construction.
    vector_dimensions: usize,
}
130
131impl SemanticRouter {
132    /// Creates a new semantic router and loads the provided routes into Redis.
133    ///
134    /// Uses [`VectorDataType::Float32`] by default. For other data types, use
135    /// [`Self::new_with_options`].
136    pub fn new<V>(
137        name: impl Into<String>,
138        redis_url: impl Into<String>,
139        routes: Vec<Route>,
140        routing_config: RoutingConfig,
141        vectorizer: V,
142    ) -> Result<Self>
143    where
144        V: Vectorizer + 'static,
145    {
146        Self::new_with_options(
147            name,
148            redis_url,
149            routes,
150            routing_config,
151            vectorizer,
152            VectorDataType::Float32,
153            false,
154        )
155    }
156
157    /// Creates a new semantic router using the default HuggingFace local
158    /// vectorizer (`AllMiniLML6V2`).
159    ///
160    /// This convenience constructor requires no API key — the model runs
161    /// locally via ONNX Runtime and is downloaded from HuggingFace Hub on
162    /// first use.
163    ///
164    /// # Errors
165    ///
166    /// Returns an error if the model cannot be loaded or the index cannot be
167    /// created.
168    #[cfg(feature = "hf-local")]
169    pub fn with_default_vectorizer(
170        name: impl Into<String>,
171        redis_url: impl Into<String>,
172        routes: Vec<Route>,
173        routing_config: RoutingConfig,
174    ) -> Result<Self> {
175        let vectorizer = crate::vectorizers::HuggingFaceTextVectorizer::new(Default::default())?;
176        Self::new(name, redis_url, routes, routing_config, vectorizer)
177    }
178
179    /// Creates a new semantic router with explicit dtype and overwrite control.
180    pub fn new_with_options<V>(
181        name: impl Into<String>,
182        redis_url: impl Into<String>,
183        routes: Vec<Route>,
184        routing_config: RoutingConfig,
185        vectorizer: V,
186        dtype: VectorDataType,
187        overwrite: bool,
188    ) -> Result<Self>
189    where
190        V: Vectorizer + 'static,
191    {
192        if routes.is_empty() {
193            return Err(Error::InvalidInput(
194                "semantic router requires at least one route".to_owned(),
195            ));
196        }
197
198        let Some(first_reference) = routes
199            .iter()
200            .flat_map(|route| route.references.iter())
201            .next()
202            .cloned()
203        else {
204            return Err(Error::InvalidInput(
205                "semantic router routes require at least one reference".to_owned(),
206            ));
207        };
208
209        let vectorizer = Arc::new(vectorizer);
210        let probe_vector = vectorizer.embed(&first_reference)?;
211        if probe_vector.is_empty() {
212            return Err(Error::InvalidInput(
213                "router vectorizer produced an empty embedding".to_owned(),
214            ));
215        }
216
217        let name = name.into();
218        let connection = RedisConnectionInfo::new(redis_url);
219        let schema = router_schema(&name, probe_vector.len(), dtype);
220        let index = SearchIndex::from_json_value(schema, connection.redis_url.clone())?;
221        let existed = index.exists().unwrap_or(false);
222
223        // Validate schema compatibility with existing index (mirrors Python behavior)
224        if !overwrite && existed {
225            let existing_index = SearchIndex::from_existing(&name, connection.redis_url.clone())?;
226            if existing_index.schema().to_json_value()? != index.schema().to_json_value()? {
227                return Err(Error::InvalidInput(format!(
228                    "Existing index {name} schema does not match the user provided schema for the semantic router. \
229                     If you wish to overwrite the index schema, set overwrite=true during initialization."
230                )));
231            }
232        }
233
234        index.create_with_options(overwrite, false)?;
235
236        let semantic_router = Self {
237            name,
238            connection,
239            routes,
240            routing_config,
241            dtype,
242            index,
243            vectorizer,
244            vector_dimensions: probe_vector.len(),
245        };
246
247        if !existed || overwrite {
248            semantic_router.load_routes()?;
249        }
250        semantic_router.persist_config()?;
251        Ok(semantic_router)
252    }
253
254    /// Reconnects to an existing semantic router stored in Redis.
255    ///
256    /// The router configuration (name, routes, routing config) must have been
257    /// previously persisted by a [`SemanticRouter::new`] call.  The vectorizer
258    /// must be supplied by the caller since vectorizers cannot be serialized.
259    ///
260    /// The `dtype` parameter must match the data type used when the router was
261    /// originally created; otherwise the schema comparison will fail.
262    pub fn from_existing<V>(
263        name: impl Into<String>,
264        redis_url: impl Into<String>,
265        vectorizer: V,
266        dtype: VectorDataType,
267    ) -> Result<Self>
268    where
269        V: Vectorizer + 'static,
270    {
271        let name = name.into();
272        let connection = RedisConnectionInfo::new(redis_url);
273        let config_key = router_config_key(&name);
274
275        let client = connection.client()?;
276        let mut conn = client.get_connection()?;
277
278        // Read persisted config
279        let raw: Option<String> = redis::cmd("JSON.GET")
280            .arg(&config_key)
281            .arg(".")
282            .query(&mut conn)?;
283        let raw = raw.ok_or_else(|| {
284            Error::InvalidInput(format!(
285                "No valid router config found for {name}. No persisted configuration exists at key '{config_key}'."
286            ))
287        })?;
288        let config_value: Value = serde_json::from_str(&raw)?;
289        let config_obj = config_value
290            .as_object()
291            .ok_or_else(|| Error::InvalidInput("Router config is not an object".to_owned()))?;
292
293        let routes_value = config_obj
294            .get("routes")
295            .ok_or_else(|| Error::InvalidInput("Router config missing 'routes'".to_owned()))?;
296        let routes: Vec<Route> = serde_json::from_value(routes_value.clone())?;
297        let routing_config_value = config_obj.get("routing_config").ok_or_else(|| {
298            Error::InvalidInput("Router config missing 'routing_config'".to_owned())
299        })?;
300        let routing_config: RoutingConfig = serde_json::from_value(routing_config_value.clone())?;
301
302        let vectorizer = Arc::new(vectorizer);
303
304        // Probe dimensions from the first route reference
305        let first_ref = routes
306            .iter()
307            .flat_map(|r| r.references.iter())
308            .next()
309            .ok_or_else(|| Error::InvalidInput("Persisted routes have no references".to_owned()))?;
310        let probe = vectorizer.embed(first_ref)?;
311        let vector_dimensions = probe.len();
312
313        let schema = router_schema(&name, vector_dimensions, dtype);
314        let index = SearchIndex::from_json_value(schema, connection.redis_url.clone())?;
315        // The index should already exist
316        if !index.exists().unwrap_or(false) {
317            return Err(Error::InvalidInput(format!(
318                "Index '{name}' does not exist in Redis"
319            )));
320        }
321
322        Ok(Self {
323            name,
324            connection,
325            routes,
326            routing_config,
327            dtype,
328            index,
329            vectorizer,
330            vector_dimensions,
331        })
332    }
333
334    /// Returns the names of the configured routes.
335    pub fn route_names(&self) -> Vec<String> {
336        self.routes.iter().map(|route| route.name.clone()).collect()
337    }
338
339    /// Returns the configured per-route thresholds.
340    pub fn route_thresholds(&self) -> HashMap<String, f32> {
341        self.routes
342            .iter()
343            .map(|route| (route.name.clone(), route.effective_threshold()))
344            .collect()
345    }
346
347    /// Returns a route by name when present.
348    pub fn get(&self, route_name: &str) -> Option<&Route> {
349        self.routes.iter().find(|route| route.name == route_name)
350    }
351
352    /// Updates the runtime routing configuration.
353    pub fn update_routing_config(&mut self, routing_config: RoutingConfig) {
354        self.routing_config = routing_config;
355    }
356
357    /// Updates route thresholds in place for any supplied route names.
358    pub fn update_route_thresholds(&mut self, route_thresholds: &HashMap<String, f32>) {
359        for route in &mut self.routes {
360            if let Some(distance_threshold) = route_thresholds.get(&route.name) {
361                route.distance_threshold = Some(*distance_threshold);
362            }
363        }
364    }
365
366    /// Routes a statement or vector to the best matching route.
367    pub fn route(&self, statement: Option<&str>, vector: Option<&[f32]>) -> Result<RouteMatch> {
368        Ok(self
369            .route_many(statement, vector, None, None)?
370            .into_iter()
371            .next()
372            .unwrap_or_else(RouteMatch::no_match))
373    }
374
375    /// Routes a statement or vector to the top matching routes.
376    pub fn route_many(
377        &self,
378        statement: Option<&str>,
379        vector: Option<&[f32]>,
380        max_k: Option<usize>,
381        aggregation_method: Option<DistanceAggregationMethod>,
382    ) -> Result<Vec<RouteMatch>> {
383        let vector = self.resolve_vector(statement, vector)?;
384        let max_threshold = self
385            .routes
386            .iter()
387            .map(Route::effective_threshold)
388            .fold(DEFAULT_ROUTE_DISTANCE_THRESHOLD, f32::max);
389        let reference_count = self
390            .routes
391            .iter()
392            .map(|route| route.references.len())
393            .sum::<usize>()
394            .max(1);
395        let query = VectorRangeQuery::new(Vector::new(vector), ROUTER_VECTOR_FIELD, max_threshold)
396            .paging(0, reference_count)
397            .with_return_fields([ROUTER_ROUTE_NAME_FIELD, ROUTER_REFERENCE_ID_FIELD]);
398
399        let documents = query_output_documents(self.index.query(&query)?)?;
400        let mut grouped: HashMap<String, Vec<f32>> = HashMap::new();
401        for document in documents {
402            let Some(route_name) = document
403                .get(ROUTER_ROUTE_NAME_FIELD)
404                .and_then(Value::as_str)
405                .map(str::to_owned)
406            else {
407                continue;
408            };
409            let Some(distance) = parse_distance(document.get("vector_distance")) else {
410                continue;
411            };
412            grouped.entry(route_name).or_default().push(distance);
413        }
414
415        let aggregation_method =
416            aggregation_method.unwrap_or(self.routing_config.aggregation_method);
417        let mut matches = self
418            .routes
419            .iter()
420            .filter_map(|route| {
421                let distances = grouped.get(&route.name)?;
422                let distance = aggregate_distances(distances, aggregation_method);
423                (distance <= route.effective_threshold()).then(|| RouteMatch {
424                    name: Some(route.name.clone()),
425                    distance: Some(distance),
426                })
427            })
428            .collect::<Vec<_>>();
429
430        matches.sort_by(|left, right| {
431            let left = left.distance.unwrap_or(f32::INFINITY);
432            let right = right.distance.unwrap_or(f32::INFINITY);
433            left.total_cmp(&right)
434        });
435        matches.truncate(max_k.unwrap_or(self.routing_config.max_k));
436        Ok(matches)
437    }
438
439    /// Adds routes to the router and loads their references into Redis.
440    pub fn add_routes(&mut self, routes: &[Route]) -> Result<()> {
441        for route in routes {
442            validate_route(route)?;
443        }
444        self.load_route_batch(routes)?;
445        for route in routes {
446            if self.get(&route.name).is_none() {
447                self.routes.push(route.clone());
448            }
449        }
450        Ok(())
451    }
452
453    /// Removes a route and its references from Redis.
454    pub fn remove_route(&mut self, route_name: &str) -> Result<()> {
455        let Some(route) = self.get(route_name).cloned() else {
456            return Ok(());
457        };
458
459        let keys = route
460            .references
461            .iter()
462            .map(|reference| self.index.key(&route_reference_id(&route.name, reference)))
463            .collect::<Vec<_>>();
464        self.index.drop_keys(&keys)?;
465        self.routes.retain(|route| route.name != route_name);
466        self.persist_config()?;
467        Ok(())
468    }
469
470    /// Clears all route references while preserving the index.
471    pub fn clear(&mut self) -> Result<usize> {
472        let deleted = self.index.clear()?;
473        self.routes.clear();
474        Ok(deleted)
475    }
476
477    /// Deletes the router index, its documents, and the persisted config key.
478    pub fn delete(&self) -> Result<()> {
479        // Remove persisted config
480        let config_key = router_config_key(&self.name);
481        let client = self.connection.client()?;
482        let mut conn = client.get_connection()?;
483        let _: usize = redis::cmd("DEL").arg(&config_key).query(&mut conn)?;
484        // Tolerate the index already being deleted (e.g. by another router
485        // instance pointing at the same name after a YAML/dict round-trip).
486        match self.index.delete(true) {
487            Ok(()) => Ok(()),
488            Err(Error::InvalidInput(msg)) if msg.contains("does not exist") => Ok(()),
489            Err(other) => Err(other),
490        }
491    }
492
493    /// Adds new references to an existing route and loads them into Redis.
494    ///
495    /// Returns the list of Redis keys created for the added references.
496    pub fn add_route_references(
497        &mut self,
498        route_name: &str,
499        references: &[String],
500    ) -> Result<Vec<String>> {
501        if references.is_empty() {
502            return Ok(Vec::new());
503        }
504
505        // Validate route exists before doing any work
506        if self.get(route_name).is_none() {
507            return Err(crate::Error::InvalidInput(format!(
508                "Route '{route_name}' not found in the SemanticRouter"
509            )));
510        }
511
512        let refs_str: Vec<&str> = references.iter().map(String::as_str).collect();
513        let embeddings = self.vectorizer.embed_many(&refs_str)?;
514        let mut records = Vec::with_capacity(references.len());
515        let mut keys = Vec::with_capacity(references.len());
516
517        for (reference, embedding) in references.iter().zip(embeddings) {
518            if embedding.len() != self.vector_dimensions {
519                return Err(crate::Error::InvalidInput(format!(
520                    "router vector dimensions mismatch: expected {}, got {}",
521                    self.vector_dimensions,
522                    embedding.len()
523                )));
524            }
525            let ref_id = route_reference_id(route_name, reference);
526            keys.push(self.index.key(&ref_id));
527            records.push(json!({
528                ROUTER_REFERENCE_ID_FIELD: ref_id,
529                ROUTER_ROUTE_NAME_FIELD: route_name,
530                ROUTER_REFERENCE_FIELD: reference,
531                ROUTER_VECTOR_FIELD: embedding,
532            }));
533        }
534
535        if !records.is_empty() {
536            let _: Vec<String> = self.index.load(&records, ROUTER_REFERENCE_ID_FIELD, None)?;
537        }
538
539        // Update the in-memory route with the new references
540        if let Some(route) = self.routes.iter_mut().find(|r| r.name == route_name) {
541            route.references.extend(references.iter().cloned());
542        }
543        self.persist_config()?;
544
545        Ok(keys)
546    }
547
548    /// Retrieves reference metadata for a route by name or by specific reference IDs.
549    ///
550    /// At least one of `route_name` or `reference_ids` must be provided.
551    pub fn get_route_references(
552        &self,
553        route_name: Option<&str>,
554        reference_ids: Option<&[String]>,
555    ) -> Result<Vec<Map<String, Value>>> {
556        let ids_to_query: Vec<String> = if let Some(ref_ids) = reference_ids {
557            ref_ids.to_vec()
558        } else if let Some(route_name) = route_name {
559            let pattern = self.route_pattern(route_name);
560            let scanned = self.scan_keys(&pattern)?;
561            let sep = self.index.key_separator();
562            let prefix = self.index.prefix();
563            let prefix_with_sep = if prefix.ends_with(sep) {
564                prefix.to_owned()
565            } else {
566                format!("{prefix}{sep}")
567            };
568            // Strip the prefix to recover the reference_id ("route_name:hash")
569            scanned
570                .into_iter()
571                .map(|key| {
572                    key.strip_prefix(&prefix_with_sep)
573                        .unwrap_or(&key)
574                        .to_owned()
575                })
576                .collect()
577        } else {
578            return Err(crate::Error::InvalidInput(
579                "Must provide a route name, reference ids, or keys to get references".to_owned(),
580            ));
581        };
582
583        let queries: Vec<crate::query::FilterQuery> = ids_to_query
584            .iter()
585            .map(|id| {
586                let filter = crate::filter::Tag::new(ROUTER_REFERENCE_ID_FIELD).eq(id.as_str());
587                crate::query::FilterQuery::new(filter).with_return_fields([
588                    ROUTER_REFERENCE_ID_FIELD,
589                    ROUTER_ROUTE_NAME_FIELD,
590                    ROUTER_REFERENCE_FIELD,
591                ])
592            })
593            .collect();
594
595        let results = self.index.batch_query(queries.iter())?;
596        let mut refs = Vec::new();
597        for result in results {
598            if let QueryOutput::Documents(docs) = result {
599                for doc in docs {
600                    refs.push(doc);
601                }
602            }
603        }
604        Ok(refs)
605    }
606
607    /// Deletes route references by route name, reference IDs, or explicit keys.
608    ///
609    /// Returns the number of references deleted.
610    pub fn delete_route_references(
611        &mut self,
612        route_name: Option<&str>,
613        reference_ids: Option<&[String]>,
614        keys: Option<&[String]>,
615    ) -> Result<usize> {
616        let keys_to_delete: Vec<String> = if let Some(explicit_keys) = keys {
617            explicit_keys.to_vec()
618        } else if let Some(ref_ids) = reference_ids {
619            // Query to find the full keys for these reference IDs
620            let queries: Vec<crate::query::FilterQuery> = ref_ids
621                .iter()
622                .map(|id| {
623                    let filter = crate::filter::Tag::new(ROUTER_REFERENCE_ID_FIELD).eq(id.as_str());
624                    crate::query::FilterQuery::new(filter).with_return_fields([
625                        ROUTER_REFERENCE_ID_FIELD,
626                        ROUTER_ROUTE_NAME_FIELD,
627                        ROUTER_REFERENCE_FIELD,
628                    ])
629                })
630                .collect();
631
632            let results = self.index.batch_query(queries.iter())?;
633            let mut found_keys = Vec::new();
634            for result in results {
635                if let QueryOutput::Documents(docs) = result {
636                    for doc in docs {
637                        // reference_id already contains "route_name:hash"
638                        // (produced by route_reference_id), so use it directly
639                        if let Some(ref_id) =
640                            doc.get(ROUTER_REFERENCE_ID_FIELD).and_then(Value::as_str)
641                        {
642                            found_keys.push(self.index.key(ref_id));
643                        }
644                    }
645                }
646            }
647            found_keys
648        } else if let Some(route_name) = route_name {
649            let pattern = self.route_pattern(route_name);
650            self.scan_keys(&pattern)?
651        } else {
652            return Err(crate::Error::InvalidInput(
653                "Must provide route_name, reference_ids, or keys to delete references".to_owned(),
654            ));
655        };
656
657        if keys_to_delete.is_empty() {
658            return Ok(0);
659        }
660
661        // Remove matching references from in-memory routes.
662        //
663        // We derive the route name and reference text from the key structure
664        // rather than fetching from Redis, because hash documents contain
665        // binary vector data that cannot be decoded as UTF-8 strings.
666        //
667        // Key format: {prefix}{sep}{route_name}{sep}{sha256_hash}
668        // Reference ID format: {route_name}:{sha256_hash}
669        let sep = self.index.key_separator();
670        let prefix_raw = self.index.prefix().trim_end_matches(sep);
671        let prefix_with_sep = if prefix_raw.is_empty() {
672            String::new()
673        } else {
674            format!("{prefix_raw}{sep}")
675        };
676        for key in &keys_to_delete {
677            let id = key.strip_prefix(&prefix_with_sep).unwrap_or(key);
678            // id is "route_name:hash" — extract route_name from the first ':'
679            if let Some((rname, _hash)) = id.split_once(':') {
680                if let Some(route) = self.routes.iter_mut().find(|r| r.name == rname) {
681                    // Find which reference produces this reference_id
682                    route
683                        .references
684                        .retain(|ref_text| route_reference_id(rname, ref_text) != id);
685                }
686            }
687        }
688
689        let deleted = self.index.drop_keys(&keys_to_delete)?;
690        self.persist_config()?;
691        Ok(deleted)
692    }
693
694    /// Serializes the router to a JSON value (equivalent to Python `to_dict()`).
695    ///
696    /// The `vectorizer` field contains a `{"type": "custom"}` marker since
697    /// Rust vectorizer implementations cannot be serialized.  Use this for
698    /// round-tripping via [`Self::from_dict`].
699    pub fn to_json_value(&self) -> Result<Value> {
700        Ok(json!({
701            "name": self.name,
702            "routes": self.routes,
703            "routing_config": self.routing_config,
704            "vectorizer": {
705                "type": "custom"
706            }
707        }))
708    }
709
710    /// Alias for [`Self::to_json_value`] matching the Python `to_dict()` name.
711    pub fn to_dict(&self) -> Result<Value> {
712        self.to_json_value()
713    }
714
715    /// Writes the router configuration to a YAML file.
716    ///
717    /// If `overwrite` is false and the file already exists, returns an error.
718    pub fn to_yaml(&self, file_path: impl AsRef<Path>, overwrite: bool) -> Result<()> {
719        let path = file_path.as_ref();
720        if path.exists() && !overwrite {
721            return Err(Error::InvalidInput(format!(
722                "Schema file {} already exists.",
723                path.display()
724            )));
725        }
726        let dict = self.to_json_value()?;
727        let file = std::fs::File::create(path)
728            .map_err(|e| Error::InvalidInput(format!("Cannot create file: {e}")))?;
729        serde_yaml::to_writer(file, &dict)
730            .map_err(|e| Error::InvalidInput(format!("YAML serialization error: {e}")))?;
731        Ok(())
732    }
733
734    /// Creates a semantic router from a YAML file.
735    ///
736    /// The caller must supply a vectorizer since vectorizer implementations
737    /// cannot be deserialized from YAML.
738    pub fn from_yaml<V>(
739        file_path: impl AsRef<Path>,
740        redis_url: impl Into<String>,
741        vectorizer: V,
742        dtype: VectorDataType,
743        overwrite: bool,
744    ) -> Result<Self>
745    where
746        V: Vectorizer + 'static,
747    {
748        let path = file_path.as_ref();
749        if !path.exists() {
750            return Err(Error::InvalidInput(format!(
751                "File {} does not exist",
752                path.display()
753            )));
754        }
755        let file = std::fs::File::open(path)
756            .map_err(|e| Error::InvalidInput(format!("Cannot open file: {e}")))?;
757        let dict: Value = serde_yaml::from_reader(file)
758            .map_err(|e| Error::InvalidInput(format!("YAML deserialization error: {e}")))?;
759        Self::from_dict(dict, redis_url, vectorizer, dtype, overwrite)
760    }
761
762    /// Creates a semantic router from a previously serialized JSON value.
763    ///
764    /// The caller must supply a vectorizer since vectorizer implementations
765    /// cannot be deserialized from JSON.
766    pub fn from_dict<V>(
767        data: Value,
768        redis_url: impl Into<String>,
769        vectorizer: V,
770        dtype: VectorDataType,
771        overwrite: bool,
772    ) -> Result<Self>
773    where
774        V: Vectorizer + 'static,
775    {
776        let obj = data
777            .as_object()
778            .ok_or_else(|| Error::InvalidInput("Router dict must be a JSON object".to_owned()))?;
779
780        let name = obj
781            .get("name")
782            .and_then(Value::as_str)
783            .ok_or_else(|| {
784                Error::InvalidInput(
785                    "Unable to load semantic router from dict: missing 'name'".to_owned(),
786                )
787            })?
788            .to_owned();
789
790        let routes_value = obj.get("routes").ok_or_else(|| {
791            Error::InvalidInput(
792                "Unable to load semantic router from dict: missing 'routes'".to_owned(),
793            )
794        })?;
795        let routes: Vec<Route> = serde_json::from_value(routes_value.clone())?;
796
797        let routing_config_value = obj.get("routing_config").ok_or_else(|| {
798            Error::InvalidInput(
799                "Unable to load semantic router from dict: missing 'routing_config'".to_owned(),
800            )
801        })?;
802        let routing_config: RoutingConfig = serde_json::from_value(routing_config_value.clone())?;
803
804        Self::new_with_options(
805            name,
806            redis_url,
807            routes,
808            routing_config,
809            vectorizer,
810            dtype,
811            overwrite,
812        )
813    }
814
815    /// Persists the current router configuration to Redis as a JSON document.
816    fn persist_config(&self) -> Result<()> {
817        let config_key = router_config_key(&self.name);
818        let dict = self.to_json_value()?;
819        let json_str = serde_json::to_string(&dict)?;
820        let client = self.connection.client()?;
821        let mut conn = client.get_connection()?;
822        let _: () = redis::cmd("JSON.SET")
823            .arg(&config_key)
824            .arg(".")
825            .arg(&json_str)
826            .query(&mut conn)?;
827        Ok(())
828    }
829
830    /// Scans Redis keys matching a glob pattern.
831    fn scan_keys(&self, pattern: &str) -> Result<Vec<String>> {
832        let client = self.connection.client()?;
833        let mut connection = client.get_connection()?;
834        let mut cursor = 0_u64;
835        let mut keys = Vec::new();
836        loop {
837            let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
838                .arg(cursor)
839                .arg("MATCH")
840                .arg(pattern)
841                .arg("COUNT")
842                .arg(100)
843                .query(&mut connection)?;
844            keys.extend(batch);
845            if next_cursor == 0 {
846                break;
847            }
848            cursor = next_cursor;
849        }
850        Ok(keys)
851    }
852
853    /// Builds a scan pattern for a route's references.
854    fn route_pattern(&self, route_name: &str) -> String {
855        let sep = self.index.key_separator();
856        let prefix = self.index.prefix().trim_end_matches(sep);
857        if prefix.is_empty() {
858            format!("{route_name}{sep}*")
859        } else {
860            format!("{prefix}{sep}{route_name}{sep}*")
861        }
862    }
863
864    fn load_routes(&self) -> Result<()> {
865        self.load_route_batch(&self.routes)
866    }
867
868    fn load_route_batch(&self, routes: &[Route]) -> Result<()> {
869        for route in routes {
870            validate_route(route)?;
871        }
872
873        let mut records = Vec::new();
874        for route in routes {
875            let refs = route
876                .references
877                .iter()
878                .map(String::as_str)
879                .collect::<Vec<_>>();
880            let embeddings = self.vectorizer.embed_many(&refs)?;
881            for (reference, embedding) in route.references.iter().zip(embeddings) {
882                if embedding.len() != self.vector_dimensions {
883                    return Err(crate::Error::InvalidInput(format!(
884                        "router vector dimensions mismatch: expected {}, got {}",
885                        self.vector_dimensions,
886                        embedding.len()
887                    )));
888                }
889                records.push(json!({
890                    ROUTER_REFERENCE_ID_FIELD: route_reference_id(&route.name, reference),
891                    ROUTER_ROUTE_NAME_FIELD: route.name,
892                    ROUTER_REFERENCE_FIELD: reference,
893                    ROUTER_VECTOR_FIELD: embedding,
894                }));
895            }
896        }
897
898        if !records.is_empty() {
899            let _: Vec<String> = self.index.load(&records, ROUTER_REFERENCE_ID_FIELD, None)?;
900        }
901        Ok(())
902    }
903
904    fn resolve_vector(&self, statement: Option<&str>, vector: Option<&[f32]>) -> Result<Vec<f32>> {
905        match (statement, vector) {
906            (_, Some(vector)) => {
907                if vector.len() != self.vector_dimensions {
908                    return Err(crate::Error::InvalidInput(format!(
909                        "router vector dimensions mismatch: expected {}, got {}",
910                        self.vector_dimensions,
911                        vector.len()
912                    )));
913                }
914                Ok(vector.to_vec())
915            }
916            (Some(statement), None) => {
917                let vector = self.vectorizer.embed(statement)?;
918                if vector.len() != self.vector_dimensions {
919                    return Err(crate::Error::InvalidInput(format!(
920                        "router vector dimensions mismatch: expected {}, got {}",
921                        self.vector_dimensions,
922                        vector.len()
923                    )));
924                }
925                Ok(vector)
926            }
927            (None, None) => Err(crate::Error::InvalidInput(
928                "must provide a statement or vector to the router".to_owned(),
929            )),
930        }
931    }
932}
933
934impl std::fmt::Debug for SemanticRouter {
935    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
936        formatter
937            .debug_struct("SemanticRouter")
938            .field("name", &self.name)
939            .field("routes", &self.routes)
940            .field("routing_config", &self.routing_config)
941            .field("vector_dimensions", &self.vector_dimensions)
942            .finish()
943    }
944}
945
946fn validate_route(route: &Route) -> Result<()> {
947    if route.name.trim().is_empty() {
948        return Err(crate::Error::InvalidInput(
949            "route name must not be empty".to_owned(),
950        ));
951    }
952    if route.references.is_empty() {
953        return Err(crate::Error::InvalidInput(
954            "route references must not be empty".to_owned(),
955        ));
956    }
957    if route
958        .references
959        .iter()
960        .any(|reference| reference.trim().is_empty())
961    {
962        return Err(crate::Error::InvalidInput(
963            "route references must not contain empty strings".to_owned(),
964        ));
965    }
966    let threshold = route.effective_threshold();
967    if !(0.0..=2.0).contains(&threshold) {
968        return Err(crate::Error::InvalidInput(format!(
969            "route distance threshold must be between 0 and 2, got {threshold}"
970        )));
971    }
972    Ok(())
973}
974
/// Returns the Redis key of a router's configuration document: the router
/// name followed by `:route_config`.
fn router_config_key(name: &str) -> String {
    const SUFFIX: &str = ":route_config";

    let mut key = String::with_capacity(name.len() + SUFFIX.len());
    key.push_str(name);
    key.push_str(SUFFIX);
    key
}
978
979fn router_schema(name: &str, vector_dimensions: usize, dtype: VectorDataType) -> Value {
980    json!({
981        "index": {
982            "name": name,
983            "prefix": name,
984            "storage_type": "hash",
985        },
986        "fields": [
987            { "name": ROUTER_REFERENCE_ID_FIELD, "type": "tag" },
988            { "name": ROUTER_ROUTE_NAME_FIELD, "type": "tag" },
989            { "name": ROUTER_REFERENCE_FIELD, "type": "text" },
990            {
991                "name": ROUTER_VECTOR_FIELD,
992                "type": "vector",
993                "attrs": {
994                    "algorithm": "flat",
995                    "dims": vector_dimensions,
996                    "datatype": dtype.as_str(),
997                    "distance_metric": "cosine"
998                }
999            }
1000        ]
1001    })
1002}
1003
1004fn route_reference_id(route_name: &str, reference: &str) -> String {
1005    format!("{route_name}:{}", hashify(reference))
1006}
1007
1008fn hashify(content: &str) -> String {
1009    use sha2::{Digest, Sha256};
1010
1011    let mut hasher = Sha256::new();
1012    hasher.update(content.as_bytes());
1013    let digest = hasher.finalize();
1014    let mut output = String::with_capacity(digest.len() * 2);
1015    for byte in digest {
1016        use std::fmt::Write as _;
1017        let _ = write!(&mut output, "{byte:02x}");
1018    }
1019    output
1020}
1021
1022fn query_output_documents(output: QueryOutput) -> Result<Vec<Map<String, Value>>> {
1023    match output {
1024        QueryOutput::Documents(documents) => Ok(documents),
1025        QueryOutput::Count(_) => Err(crate::Error::InvalidInput(
1026            "router queries must return documents".to_owned(),
1027        )),
1028    }
1029}
1030
1031fn parse_distance(value: Option<&Value>) -> Option<f32> {
1032    match value {
1033        Some(Value::Number(number)) => number.as_f64().map(|value| value as f32),
1034        Some(Value::String(value)) => value.parse::<f32>().ok(),
1035        _ => None,
1036    }
1037}
1038
1039fn aggregate_distances(distances: &[f32], aggregation_method: DistanceAggregationMethod) -> f32 {
1040    match aggregation_method {
1041        DistanceAggregationMethod::Avg => distances.iter().sum::<f32>() / distances.len() as f32,
1042        DistanceAggregationMethod::Min => distances.iter().copied().fold(f32::INFINITY, f32::min),
1043        DistanceAggregationMethod::Sum => distances.iter().sum::<f32>(),
1044    }
1045}
1046
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a route with a single reference and default settings.
    fn single_reference_route(name: &str, reference: &str) -> Route {
        Route::new(name, vec![reference.to_owned()])
    }

    /// Returns the `datatype` attribute of the schema's vector field.
    fn vector_datatype(schema: &Value) -> &Value {
        let fields = schema["fields"].as_array().unwrap();
        let vector_field = fields
            .iter()
            .find(|field| field["name"].as_str() == Some("vector"))
            .unwrap();
        &vector_field["attrs"]["datatype"]
    }

    // --- Route and RouteMatch -------------------------------------------

    #[test]
    fn route_new_defaults() {
        let created = single_reference_route("test", "ref1");
        assert_eq!(created.name, "test");
        assert_eq!(created.references, vec!["ref1".to_owned()]);
        assert!(created.metadata.is_empty());
        assert!(created.distance_threshold.is_none());
    }

    #[test]
    fn route_effective_threshold_default() {
        let created = single_reference_route("test", "ref1");
        assert_eq!(
            created.effective_threshold(),
            DEFAULT_ROUTE_DISTANCE_THRESHOLD
        );
    }

    #[test]
    fn route_effective_threshold_custom() {
        let mut customized = single_reference_route("test", "ref1");
        customized.distance_threshold = Some(0.3);
        assert_eq!(customized.effective_threshold(), 0.3);
    }

    #[test]
    fn route_match_no_match() {
        let empty_match = RouteMatch::no_match();
        assert!(empty_match.name.is_none());
        assert!(empty_match.distance.is_none());
    }

    // --- RoutingConfig --------------------------------------------------

    #[test]
    fn routing_config_default() {
        let defaults = RoutingConfig::default();
        assert_eq!(defaults.max_k, 1);
        assert_eq!(defaults.aggregation_method, DistanceAggregationMethod::Avg);
    }

    #[test]
    fn routing_config_serde_roundtrip() {
        let original = RoutingConfig {
            max_k: 5,
            aggregation_method: DistanceAggregationMethod::Min,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RoutingConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.max_k, 5);
        assert_eq!(decoded.aggregation_method, DistanceAggregationMethod::Min);
    }

    // --- validate_route -------------------------------------------------

    #[test]
    fn validate_route_ok() {
        let mut valid = single_reference_route("greeting", "hello");
        valid.distance_threshold = Some(0.3);
        assert!(validate_route(&valid).is_ok());
    }

    #[test]
    fn validate_route_empty_name() {
        let unnamed = single_reference_route("", "hello");
        assert!(validate_route(&unnamed).is_err());
    }

    #[test]
    fn validate_route_whitespace_name() {
        let blank_name = single_reference_route("  ", "hello");
        assert!(validate_route(&blank_name).is_err());
    }

    #[test]
    fn validate_route_empty_references() {
        let no_references = Route::new("test", Vec::new());
        assert!(validate_route(&no_references).is_err());
    }

    #[test]
    fn validate_route_empty_reference_string() {
        let blank_reference = single_reference_route("test", "");
        assert!(validate_route(&blank_reference).is_err());
    }

    #[test]
    fn validate_route_bad_threshold() {
        // Below the lower bound.
        let mut too_low = single_reference_route("test", "hello");
        too_low.distance_threshold = Some(-0.1);
        assert!(validate_route(&too_low).is_err());

        // Above the upper bound.
        let mut too_high = single_reference_route("test", "hello");
        too_high.distance_threshold = Some(2.5);
        assert!(validate_route(&too_high).is_err());
    }

    // --- reference ids and hashing --------------------------------------

    #[test]
    fn route_reference_id_deterministic() {
        let first = route_reference_id("greeting", "hello");
        let second = route_reference_id("greeting", "hello");
        assert_eq!(first, second);
    }

    #[test]
    fn route_reference_id_different_for_different_refs() {
        let hello_id = route_reference_id("greeting", "hello");
        let hi_id = route_reference_id("greeting", "hi");
        assert_ne!(hello_id, hi_id);
    }

    #[test]
    fn route_reference_id_different_for_different_routes() {
        let greeting_id = route_reference_id("greeting", "hello");
        let farewell_id = route_reference_id("farewell", "hello");
        assert_ne!(greeting_id, farewell_id);
    }

    #[test]
    fn hashify_deterministic() {
        assert_eq!(hashify("test"), hashify("test"));
        assert_ne!(hashify("test"), hashify("other"));
    }

    // --- aggregate_distances --------------------------------------------

    #[test]
    fn aggregate_distances_avg() {
        let average = aggregate_distances(&[0.1, 0.3], DistanceAggregationMethod::Avg);
        assert!((average - 0.2).abs() < 1e-6);
    }

    #[test]
    fn aggregate_distances_min() {
        let smallest = aggregate_distances(&[0.3, 0.1, 0.5], DistanceAggregationMethod::Min);
        assert!((smallest - 0.1).abs() < 1e-6);
    }

    #[test]
    fn aggregate_distances_sum() {
        let total = aggregate_distances(&[0.1, 0.2, 0.3], DistanceAggregationMethod::Sum);
        assert!((total - 0.6).abs() < 1e-6);
    }

    // --- parse_distance -------------------------------------------------

    #[test]
    fn parse_distance_number() {
        let number = json!(0.5);
        assert_eq!(parse_distance(Some(&number)), Some(0.5));
    }

    #[test]
    fn parse_distance_string() {
        let text = json!("0.25");
        assert_eq!(parse_distance(Some(&text)), Some(0.25));
    }

    #[test]
    fn parse_distance_none() {
        assert_eq!(parse_distance(None), None);
    }

    #[test]
    fn parse_distance_invalid() {
        let text = json!("not_a_number");
        assert_eq!(parse_distance(Some(&text)), None);
    }

    // --- schema, keys and serde -----------------------------------------

    #[test]
    fn router_schema_structure() {
        let schema = router_schema("my_router", 64, VectorDataType::Float32);

        let index = &schema["index"];
        assert_eq!(index["name"], "my_router");
        assert_eq!(index["prefix"], "my_router");
        assert_eq!(index["storage_type"], "hash");

        let fields = schema["fields"].as_array().unwrap();
        let field_names: Vec<&str> = fields
            .iter()
            .filter_map(|field| field["name"].as_str())
            .collect();
        for expected in ["reference_id", "route_name", "reference", "vector"] {
            assert!(field_names.contains(&expected));
        }

        let vector_field = fields
            .iter()
            .find(|field| field["name"].as_str() == Some("vector"))
            .unwrap();
        assert_eq!(vector_field["attrs"]["dims"], 64);
        assert_eq!(vector_field["attrs"]["datatype"], "float32");
    }

    #[test]
    fn distance_aggregation_method_serde() {
        let encoded = serde_json::to_string(&DistanceAggregationMethod::Min).unwrap();
        assert_eq!(encoded, "\"min\"");
        let decoded: DistanceAggregationMethod = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, DistanceAggregationMethod::Min);
    }

    #[test]
    fn route_serde_roundtrip() {
        let original = Route {
            name: "test".to_owned(),
            references: vec!["hello".to_owned(), "hi".to_owned()],
            metadata: Map::from_iter([("type".to_owned(), json!("greeting"))]),
            distance_threshold: Some(0.3),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Route = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.name, "test");
        assert_eq!(decoded.references, vec!["hello", "hi"]);
        assert_eq!(decoded.distance_threshold, Some(0.3));
    }

    #[test]
    fn router_config_key_format() {
        assert_eq!(router_config_key("my_router"), "my_router:route_config");
    }

    #[test]
    fn router_schema_respects_dtype() {
        let cases = [
            (VectorDataType::Float64, "float64"),
            (VectorDataType::Bfloat16, "bfloat16"),
            (VectorDataType::Float16, "float16"),
        ];
        for (dtype, expected) in cases {
            let schema = router_schema("my_router", 64, dtype);
            assert_eq!(*vector_datatype(&schema), expected);
        }
    }
}