Skip to main content

haystack_server/
federation.rs

1//! Federation manager — coordinates multiple remote connectors for federated queries.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::connector::{Connector, ConnectorConfig, ConnectorState};
7use crate::domain_scope::DomainScope;
8use haystack_core::data::HDict;
9use haystack_core::filter::parse_filter;
10
11/// TOML file structure for federation configuration.
12///
13/// Example:
14/// ```toml
15/// [connectors.building-a]
16/// name = "Building A"
17/// url = "http://building-a:8080/api"
18/// username = "federation"
19/// password = "s3cret"
20/// ```
21#[derive(serde::Deserialize)]
22struct FederationToml {
23    connectors: HashMap<String, ConnectorConfig>,
24}
25
26/// Manages multiple remote connectors for federated queries.
27pub struct Federation {
28    pub connectors: Vec<Arc<Connector>>,
29}
30
31impl Federation {
32    /// Create a new federation with no connectors.
33    pub fn new() -> Self {
34        Self {
35            connectors: Vec::new(),
36        }
37    }
38
39    /// Add a connector for a remote Haystack server.
40    pub fn add(&mut self, config: ConnectorConfig) -> Result<(), String> {
41        config.validate()?;
42        self.connectors.push(Arc::new(Connector::new(config)));
43        Ok(())
44    }
45
46    /// Sync a single connector by name, returning the entity count on success.
47    pub async fn sync_one(&self, name: &str) -> Result<usize, String> {
48        for connector in &self.connectors {
49            if connector.config.name == name {
50                return connector.sync().await;
51            }
52        }
53        Err(format!("connector not found: {name}"))
54    }
55
56    /// Sync all connectors, returning a vec of (name, result) pairs.
57    ///
58    /// Each result is either `Ok(count)` with the number of entities synced,
59    /// or `Err(message)` with the error description.
60    pub async fn sync_all(&self) -> Vec<(String, Result<usize, String>)> {
61        let mut results = Vec::new();
62        for connector in &self.connectors {
63            let name = connector.config.name.clone();
64            let result = connector.sync().await;
65            results.push((name, result));
66        }
67        results
68    }
69
70    /// Returns all cached entities from all connectors, merged into a single vec.
71    pub fn all_cached_entities(&self) -> Vec<Arc<HDict>> {
72        let mut all = Vec::new();
73        for connector in &self.connectors {
74            all.extend(connector.cached_entities());
75        }
76        all
77    }
78
79    /// Filter cached entities across all connectors using bitmap-accelerated queries.
80    ///
81    /// Each connector uses its own bitmap tag index for fast filtering, then
82    /// results are merged up to the given limit. Much faster than linear scan
83    /// for tag-based queries over large federated entity sets.
84    pub fn filter_cached_entities(
85        &self,
86        filter_expr: &str,
87        limit: usize,
88    ) -> Result<Vec<Arc<HDict>>, String> {
89        let effective_limit = if limit == 0 { usize::MAX } else { limit };
90        let ast = parse_filter(filter_expr).map_err(|e| format!("filter error: {e}"))?;
91
92        let mut all = Vec::new();
93        for connector in &self.connectors {
94            if all.len() >= effective_limit {
95                break;
96            }
97            let remaining = effective_limit - all.len();
98            all.extend(connector.filter_cached_with_ast(&ast, remaining));
99        }
100        Ok(all)
101    }
102
103    /// Returns the number of connectors.
104    pub fn connector_count(&self) -> usize {
105        self.connectors.len()
106    }
107
108    /// Returns the connector that owns the entity with the given ID, if any.
109    pub fn owner_of(&self, id: &str) -> Option<&Arc<Connector>> {
110        self.connectors.iter().find(|c| c.owns(id))
111    }
112
113    /// Get connectors that match a domain scope.
114    pub fn connectors_for_scope(&self, scope: &DomainScope) -> Vec<&Arc<Connector>> {
115        self.connectors
116            .iter()
117            .filter(|c| scope.includes(c.config.domain.as_deref()))
118            .collect()
119    }
120
121    /// Get cached entities from connectors matching the scope.
122    pub fn cached_entities_for_scope(&self, scope: &DomainScope) -> Vec<Arc<HDict>> {
123        self.connectors_for_scope(scope)
124            .iter()
125            .flat_map(|c| c.cached_entities())
126            .collect()
127    }
128
129    /// Returns observable state for each connector.
130    pub fn connector_states(&self) -> Vec<ConnectorState> {
131        self.connectors.iter().map(|c| c.state()).collect()
132    }
133
134    /// Batch read entities by ID across federated connectors.
135    ///
136    /// Groups IDs by owning connector and fetches each group in a single
137    /// indexed lookup (O(1) per ID via `cache_id_map`), avoiding repeated
138    /// linear scans. Returns `(found_entities, missing_ids)`.
139    pub fn batch_read_by_id<'a>(
140        &self,
141        ids: impl IntoIterator<Item = &'a str>,
142    ) -> (Vec<Arc<HDict>>, Vec<String>) {
143        // Group IDs by connector index.
144        let mut groups: HashMap<usize, Vec<&str>> = HashMap::new();
145        let mut not_owned: Vec<String> = Vec::new();
146
147        for id in ids {
148            let mut found = false;
149            for (idx, connector) in self.connectors.iter().enumerate() {
150                if connector.owns(id) {
151                    groups.entry(idx).or_default().push(id);
152                    found = true;
153                    break;
154                }
155            }
156            if !found {
157                not_owned.push(id.to_string());
158            }
159        }
160
161        // Fetch each group from its connector in a single pass.
162        let mut all_found = Vec::new();
163        for (idx, ids) in &groups {
164            let (found, mut missing) = self.connectors[*idx].batch_get_cached(ids);
165            all_found.extend(found);
166            not_owned.append(&mut missing);
167        }
168
169        (all_found, not_owned)
170    }
171
172    /// Returns `(name, entity_count)` for each connector.
173    pub fn status(&self) -> Vec<(String, usize)> {
174        self.connectors
175            .iter()
176            .map(|c| (c.config.name.clone(), c.entity_count()))
177            .collect()
178    }
179
180    /// Parse a TOML string into a `Federation`, adding each connector defined
181    /// under `[connectors.<key>]`.
182    pub fn from_toml_str(toml_str: &str) -> Result<Self, String> {
183        let parsed: FederationToml =
184            toml::from_str(toml_str).map_err(|e| format!("invalid federation TOML: {e}"))?;
185        let mut fed = Self::new();
186        for (_key, config) in parsed.connectors {
187            fed.add(config)?;
188        }
189        Ok(fed)
190    }
191
192    /// Read a TOML file from disk and parse it into a `Federation`.
193    pub fn from_toml_file(path: &str) -> Result<Self, String> {
194        let contents =
195            std::fs::read_to_string(path).map_err(|e| format!("failed to read {path}: {e}"))?;
196        Self::from_toml_str(&contents)
197    }
198
199    /// Start background sync tasks for all connectors.
200    ///
201    /// Each connector gets its own tokio task that loops at its configured
202    /// sync interval, reconnecting automatically on failure.
203    /// Returns the join handles (they run until the server shuts down).
204    pub fn start_background_sync(&self) -> Vec<tokio::task::JoinHandle<()>> {
205        self.connectors
206            .iter()
207            .map(|c| Connector::spawn_sync_task(Arc::clone(c)))
208            .collect()
209    }
210}
211
212impl Default for Federation {
213    fn default() -> Self {
214        Self::new()
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use haystack_core::kinds::{HRef, Kind};

    /// Builds a connector config with no WebSocket URL, no TLS material and the
    /// default sync interval. Only the fields the tests vary are parameters.
    fn cfg(
        name: &str,
        url: &str,
        username: &str,
        password: &str,
        id_prefix: Option<&str>,
        domain: Option<&str>,
    ) -> ConnectorConfig {
        ConnectorConfig {
            name: name.to_string(),
            url: url.to_string(),
            username: username.to_string(),
            password: password.to_string(),
            id_prefix: id_prefix.map(str::to_string),
            ws_url: None,
            sync_interval_secs: None,
            client_cert: None,
            client_key: None,
            ca_cert: None,
            domain: domain.map(str::to_string),
        }
    }

    /// Creates an entity dict carrying only an `id` ref.
    fn entity_with_id(id: &str) -> HDict {
        let mut dict = HDict::new();
        dict.set("id", Kind::Ref(HRef::from_val(id)));
        dict
    }

    /// Two connectors in distinct domains (`site-a`, `site-b`), each holding one
    /// cached entity.
    fn two_domain_federation() -> Federation {
        let mut fed = Federation::new();
        fed.add(cfg("a", "http://a:8080/api", "u", "p", Some("a-"), Some("site-a")))
            .unwrap();
        fed.add(cfg("b", "http://b:8080/api", "u", "p", Some("b-"), Some("site-b")))
            .unwrap();
        fed.connectors[0].update_cache(vec![entity_with_id("a-s1")]);
        fed.connectors[1].update_cache(vec![entity_with_id("b-s1")]);
        fed
    }

    #[test]
    fn federation_new_empty() {
        let fed = Federation::new();
        assert_eq!(fed.connector_count(), 0);
        assert!(fed.all_cached_entities().is_empty());
        assert!(fed.status().is_empty());
    }

    #[test]
    fn federation_add_connector() {
        let mut fed = Federation::new();
        assert_eq!(fed.connector_count(), 0);

        fed.add(cfg("server-1", "http://localhost:8080/api", "user", "pass", None, None))
            .unwrap();
        assert_eq!(fed.connector_count(), 1);

        fed.add(cfg(
            "server-2",
            "http://localhost:8081/api",
            "user",
            "pass",
            Some("s2-"),
            None,
        ))
        .unwrap();
        assert_eq!(fed.connector_count(), 2);
    }

    #[test]
    fn federation_status_empty() {
        assert!(Federation::new().status().is_empty());
    }

    #[test]
    fn federation_status_with_connectors() {
        let mut fed = Federation::new();
        fed.add(cfg("alpha", "http://alpha:8080/api", "user", "pass", None, None))
            .unwrap();
        fed.add(cfg("beta", "http://beta:8080/api", "user", "pass", Some("b-"), None))
            .unwrap();

        let status = fed.status();
        assert_eq!(status.len(), 2);
        assert_eq!(status[0].0, "alpha");
        assert_eq!(status[0].1, 0); // nothing synced yet
        assert_eq!(status[1].0, "beta");
        assert_eq!(status[1].1, 0);
    }

    #[test]
    fn federation_owner_of_returns_correct_connector() {
        let mut fed = Federation::new();
        fed.add(cfg("alpha", "http://alpha:8080/api", "user", "pass", Some("a-"), None))
            .unwrap();

        // Stand in for a completed sync by seeding alpha's cache directly.
        fed.connectors[0].update_cache(vec![entity_with_id("a-site-1")]);

        assert!(fed.owner_of("a-site-1").is_some());
        assert_eq!(fed.owner_of("a-site-1").unwrap().config.name, "alpha");
        assert!(fed.owner_of("unknown-1").is_none());
    }

    #[test]
    fn federation_from_toml_str() {
        let toml = r#"
[connectors.building-a]
name = "Building A"
url = "http://building-a:8080/api"
username = "federation"
password = "s3cret"
id_prefix = "bldg-a-"
sync_interval_secs = 30

[connectors.building-b]
name = "Building B"
url = "https://building-b:8443/api"
username = "federation"
password = "s3cret"
id_prefix = "bldg-b-"
client_cert = "/etc/certs/federation.pem"
client_key = "/etc/certs/federation-key.pem"
ca_cert = "/etc/certs/ca.pem"
"#;
        let fed = Federation::from_toml_str(toml).unwrap();
        assert_eq!(fed.connector_count(), 2);

        let status = fed.status();
        let names: Vec<&str> = status.iter().map(|(n, _)| n.as_str()).collect();
        assert!(names.contains(&"Building A"));
        assert!(names.contains(&"Building B"));
    }

    #[test]
    fn federation_from_toml_str_empty() {
        let fed = Federation::from_toml_str("[connectors]\n").unwrap();
        assert_eq!(fed.connector_count(), 0);
    }

    #[test]
    fn federation_from_toml_str_invalid() {
        assert!(Federation::from_toml_str("not valid toml {{{}").is_err());
    }

    #[test]
    fn federation_all_cached_entities_empty() {
        let mut fed = Federation::new();
        fed.add(cfg("server", "http://localhost:8080/api", "user", "pass", None, None))
            .unwrap();
        // Nothing was synced, so no entities are cached.
        assert!(fed.all_cached_entities().is_empty());
    }

    #[test]
    fn cached_entities_for_scope_wildcard() {
        let fed = two_domain_federation();
        // The wildcard scope sees every connector.
        let all = fed.cached_entities_for_scope(&DomainScope::all());
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn cached_entities_for_scope_scoped() {
        let fed = two_domain_federation();
        // Restricting to site-a hides connector "b".
        let scoped = fed.cached_entities_for_scope(&DomainScope::scoped(["site-a".to_string()]));
        assert_eq!(scoped.len(), 1);
    }

    #[test]
    fn connector_states_populated() {
        let mut fed = Federation::new();
        fed.add(cfg("alpha", "http://alpha:8080/api", "u", "p", None, None))
            .unwrap();

        let states = fed.connector_states();
        assert_eq!(states.len(), 1);

        let state = &states[0];
        assert_eq!(state.name, "alpha");
        assert!(!state.connected);
        assert_eq!(state.cache_version, 0);
        assert_eq!(state.entity_count, 0);
    }
}