Skip to main content

optirs_learned/
cross_domain_transfer.rs

//! Cross-domain knowledge transfer between domain optimizers.
//!
//! This module enables transferring learned optimization knowledge from one
//! domain (e.g., computer vision) to another (e.g., NLP) by maintaining
//! shared representations and computing domain similarity. The transfer
//! mechanism blends source-domain knowledge into the target domain, weighted
//! by a transferability score derived from the cosine similarity of the
//! domains' shared representations.
9
10use crate::error::{OptimError, Result};
11use scirs2_core::ndarray::{Array1, Array2, ScalarOperand, Zip};
12use scirs2_core::numeric::Float;
13use std::collections::HashMap;
14use std::fmt::Debug;
15
16// ---------------------------------------------------------------------------
17// Domain knowledge
18// ---------------------------------------------------------------------------
19
20/// Knowledge captured for a single domain.
21///
22/// Separates the shared backbone representation (usable across domains) from
23/// domain-specific parameters, and tracks performance history for analytics.
24#[derive(Debug, Clone)]
25pub struct DomainKnowledge<T: Float + Debug + Send + Sync + 'static> {
26    /// Human-readable domain name.
27    pub domain_name: String,
28    /// Features that form the shared representation across domains.
29    pub shared_representation: Array1<T>,
30    /// Parameters specific to this domain.
31    pub domain_specific_params: Array1<T>,
32    /// Historical performance values recorded in this domain.
33    pub performance_history: Vec<T>,
34}
35
36// ---------------------------------------------------------------------------
37// Shared representation
38// ---------------------------------------------------------------------------
39
40/// Shared backbone features maintained across all registered domains.
41#[derive(Debug, Clone)]
42pub struct SharedRepresentation<T: Float + Debug + Send + Sync + 'static> {
43    /// Shared feature vector.
44    pub features: Array1<T>,
45    /// Dimensionality of the shared features.
46    pub dimension: usize,
47    /// Version counter incremented on every update.
48    pub version: usize,
49}
50
51impl<T: Float + Debug + Send + Sync + 'static> SharedRepresentation<T> {
52    /// Create a new zero-initialised shared representation of the given dimension.
53    pub fn new(dimension: usize) -> Self {
54        Self {
55            features: Array1::<T>::zeros(dimension),
56            dimension,
57            version: 0,
58        }
59    }
60}
61
62// ---------------------------------------------------------------------------
63// Transfer result
64// ---------------------------------------------------------------------------
65
66/// Outcome of a cross-domain transfer operation.
67#[derive(Debug, Clone)]
68pub struct TransferResult<T: Float + Debug + Send + Sync + 'static> {
69    /// Parameters produced by the transfer.
70    pub transferred_params: Array1<T>,
71    /// Scalar in [0, 1] indicating how transferable the source is to the target.
72    pub transferability_score: T,
73    /// Name of the source domain.
74    pub source_domain: String,
75    /// Name of the target domain.
76    pub target_domain: String,
77}
78
79// ---------------------------------------------------------------------------
80// Cross-domain transfer engine
81// ---------------------------------------------------------------------------
82
83/// Engine for cross-domain knowledge transfer.
84///
85/// Maintains a registry of domain-specific knowledge and a global shared
86/// representation. Supports computing pairwise domain similarities and
87/// performing transfer of knowledge from one domain to another.
88#[derive(Debug)]
89pub struct CrossDomainTransfer<T: Float + Debug + Send + Sync + 'static> {
90    /// Registered domains keyed by name.
91    domains: HashMap<String, DomainKnowledge<T>>,
92    /// Global shared representation.
93    shared_repr: SharedRepresentation<T>,
94    /// History of transfers performed.
95    transfer_history: Vec<TransferResult<T>>,
96}
97
98impl<T: Float + Debug + Send + Sync + 'static + ScalarOperand> CrossDomainTransfer<T> {
99    /// Create a new engine with a shared representation of the given dimension.
100    pub fn new(shared_dim: usize) -> Self {
101        Self {
102            domains: HashMap::new(),
103            shared_repr: SharedRepresentation::new(shared_dim),
104            transfer_history: Vec::new(),
105        }
106    }
107
108    /// Get a reference to the transfer history.
109    pub fn transfer_history(&self) -> &[TransferResult<T>] {
110        &self.transfer_history
111    }
112
113    /// Get a reference to the shared representation.
114    pub fn shared_representation(&self) -> &SharedRepresentation<T> {
115        &self.shared_repr
116    }
117
118    // -----------------------------------------------------------------
119    // Domain management
120    // -----------------------------------------------------------------
121
122    /// Register a new domain or update an existing one.
123    ///
124    /// # Errors
125    /// Returns `OptimError::InvalidConfig` if the domain's shared
126    /// representation dimension does not match the engine's shared dimension.
127    pub fn register_domain(&mut self, knowledge: DomainKnowledge<T>) -> Result<()> {
128        if knowledge.shared_representation.len() != self.shared_repr.dimension {
129            return Err(OptimError::InvalidConfig(format!(
130                "Shared representation dimension mismatch: expected {}, got {}",
131                self.shared_repr.dimension,
132                knowledge.shared_representation.len()
133            )));
134        }
135        self.domains
136            .insert(knowledge.domain_name.clone(), knowledge);
137        Ok(())
138    }
139
140    /// Return the names of all registered domains (sorted for determinism).
141    pub fn get_registered_domains(&self) -> Vec<&str> {
142        let mut names: Vec<&str> = self.domains.keys().map(|s| s.as_str()).collect();
143        names.sort();
144        names
145    }
146
147    // -----------------------------------------------------------------
148    // Similarity
149    // -----------------------------------------------------------------
150
151    /// Compute the cosine similarity between two domains' shared representations.
152    ///
153    /// The result is in [-1, 1]. If both vectors are zero the similarity is
154    /// defined as zero.
155    ///
156    /// # Errors
157    /// Returns `OptimError::InvalidState` if either domain is not registered.
158    pub fn compute_domain_similarity(&self, source: &str, target: &str) -> Result<T> {
159        let src = self.get_domain(source)?;
160        let tgt = self.get_domain(target)?;
161
162        let dot = dot_product(&src.shared_representation, &tgt.shared_representation);
163        let norm_src = l2_norm(&src.shared_representation);
164        let norm_tgt = l2_norm(&tgt.shared_representation);
165
166        let denom = norm_src * norm_tgt;
167        if denom <= T::zero() {
168            return Ok(T::zero());
169        }
170        Ok(dot / denom)
171    }
172
173    // -----------------------------------------------------------------
174    // Transfer
175    // -----------------------------------------------------------------
176
177    /// Transfer knowledge from `source` domain to `target` domain.
178    ///
179    /// The adapted parameters are computed as:
180    /// ```text
181    /// transferred = target_specific
182    ///             + transferability * (source_shared - target_shared)
183    /// ```
184    /// where `transferability` is `(cosine_similarity + 1) / 2` (mapped to
185    /// [0, 1]).
186    ///
187    /// # Errors
188    /// Returns `OptimError::InvalidState` if either domain is not registered,
189    /// or `OptimError::ComputationError` if domain-specific parameter
190    /// dimensions differ.
191    pub fn transfer(&mut self, source: &str, target: &str) -> Result<TransferResult<T>> {
192        // We need to clone to avoid double borrow.
193        let src = self.get_domain(source)?.clone();
194        let tgt = self.get_domain(target)?.clone();
195
196        if src.domain_specific_params.len() != tgt.domain_specific_params.len() {
197            return Err(OptimError::ComputationError(format!(
198                "Domain-specific parameter dimension mismatch: source {} vs target {}",
199                src.domain_specific_params.len(),
200                tgt.domain_specific_params.len()
201            )));
202        }
203
204        let similarity = self.compute_domain_similarity(source, target)?;
205        // Map [-1, 1] -> [0, 1]
206        let two = T::from(2.0).unwrap_or_else(|| T::one() + T::one());
207        let transferability = (similarity + T::one()) / two;
208
209        // Build transferred parameters
210        let dim = tgt.domain_specific_params.len();
211        let mut transferred = Array1::<T>::zeros(dim);
212
213        // Shared-representation difference (potentially different length from
214        // domain_specific_params). We only use it element-wise up to the
215        // minimum shared dimension, padding the rest with zero.
216        let shared_dim = src
217            .shared_representation
218            .len()
219            .min(tgt.shared_representation.len());
220        let mut shared_diff = Array1::<T>::zeros(dim);
221        for i in 0..shared_dim.min(dim) {
222            shared_diff[i] = src.shared_representation[i] - tgt.shared_representation[i];
223        }
224
225        Zip::from(&mut transferred)
226            .and(&tgt.domain_specific_params)
227            .and(&shared_diff)
228            .for_each(|out, &tgt_p, &sd| {
229                *out = tgt_p + transferability * sd;
230            });
231
232        let result = TransferResult {
233            transferred_params: transferred,
234            transferability_score: transferability,
235            source_domain: source.to_string(),
236            target_domain: target.to_string(),
237        };
238
239        self.transfer_history.push(result.clone());
240        Ok(result)
241    }
242
243    // -----------------------------------------------------------------
244    // Shared representation update
245    // -----------------------------------------------------------------
246
247    /// Update the shared representation using gradients from a specific domain.
248    ///
249    /// Applies a simple gradient-descent step:
250    /// `shared_features -= lr * gradients`
251    /// and also updates the domain's own shared-representation snapshot.
252    ///
253    /// # Errors
254    /// Returns `OptimError::InvalidState` if the domain is not registered, or
255    /// `OptimError::ComputationError` on dimension mismatch.
256    pub fn update_shared_representation(
257        &mut self,
258        domain_name: &str,
259        gradients: &Array1<T>,
260        lr: T,
261    ) -> Result<()> {
262        if gradients.len() != self.shared_repr.dimension {
263            return Err(OptimError::ComputationError(format!(
264                "Gradient dimension {} does not match shared dimension {}",
265                gradients.len(),
266                self.shared_repr.dimension
267            )));
268        }
269
270        // Check domain exists
271        if !self.domains.contains_key(domain_name) {
272            return Err(OptimError::InvalidState(format!(
273                "Domain '{}' is not registered",
274                domain_name
275            )));
276        }
277
278        // Update global shared representation
279        Zip::from(&mut self.shared_repr.features)
280            .and(gradients)
281            .for_each(|f, &g| {
282                *f = *f - lr * g;
283            });
284        self.shared_repr.version += 1;
285
286        // Mirror into the domain's snapshot
287        if let Some(domain) = self.domains.get_mut(domain_name) {
288            Zip::from(&mut domain.shared_representation)
289                .and(&self.shared_repr.features)
290                .for_each(|d, &s| {
291                    *d = s;
292                });
293        }
294
295        Ok(())
296    }
297
298    // -----------------------------------------------------------------
299    // Transferability matrix
300    // -----------------------------------------------------------------
301
302    /// Compute an NxN pairwise transferability (cosine-similarity) matrix.
303    ///
304    /// Rows and columns are ordered by sorted domain name. The diagonal is 1.0
305    /// (self-similarity).
306    ///
307    /// # Errors
308    /// Returns `OptimError::InsufficientData` if fewer than two domains are
309    /// registered.
310    pub fn get_transferability_matrix(&self) -> Result<Array2<T>> {
311        let names = self.get_registered_domains();
312        let n = names.len();
313        if n < 2 {
314            return Err(OptimError::InsufficientData(
315                "Need at least 2 registered domains to build a transferability matrix".into(),
316            ));
317        }
318
319        let mut matrix = Array2::<T>::zeros((n, n));
320        for (i, &src) in names.iter().enumerate() {
321            for (j, &tgt) in names.iter().enumerate() {
322                if i == j {
323                    matrix[[i, j]] = T::one();
324                } else {
325                    matrix[[i, j]] = self.compute_domain_similarity(src, tgt)?;
326                }
327            }
328        }
329        Ok(matrix)
330    }
331
332    // -----------------------------------------------------------------
333    // Private helpers
334    // -----------------------------------------------------------------
335
336    /// Retrieve a domain by name or return an error.
337    fn get_domain(&self, name: &str) -> Result<&DomainKnowledge<T>> {
338        self.domains
339            .get(name)
340            .ok_or_else(|| OptimError::InvalidState(format!("Domain '{}' is not registered", name)))
341    }
342}
343
344// ---------------------------------------------------------------------------
345// Free-standing math helpers
346// ---------------------------------------------------------------------------
347
348/// Dot product of two arrays.
349fn dot_product<T: Float + Debug + Send + Sync + 'static>(a: &Array1<T>, b: &Array1<T>) -> T {
350    a.iter()
351        .zip(b.iter())
352        .fold(T::zero(), |acc, (&x, &y)| acc + x * y)
353}
354
355/// L2 norm of an array.
356fn l2_norm<T: Float + Debug + Send + Sync + 'static>(arr: &Array1<T>) -> T {
357    arr.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt()
358}
359
360// ===========================================================================
361// Tests
362// ===========================================================================
363
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Domain whose shared vector is `[1, 2, ..., shared_dim]`, negated when the
    /// name contains "nlp", with every specific parameter set to 0.5.
    fn make_domain(name: &str, shared_dim: usize, specific_dim: usize) -> DomainKnowledge<f64> {
        let sign = if name.contains("nlp") { -1.0 } else { 1.0 };
        let values: Vec<f64> = (0..shared_dim).map(|k| (k as f64 + 1.0) * sign).collect();
        domain_with(
            name,
            values,
            Array1::from_elem(specific_dim, 0.5),
            vec![0.8, 0.85, 0.9],
        )
    }

    /// Assemble a `DomainKnowledge` from its explicit parts.
    fn domain_with(
        name: &str,
        shared: Vec<f64>,
        specific: Array1<f64>,
        history: Vec<f64>,
    ) -> DomainKnowledge<f64> {
        DomainKnowledge {
            domain_name: name.to_string(),
            shared_representation: Array1::from_vec(shared),
            domain_specific_params: specific,
            performance_history: history,
        }
    }

    #[test]
    fn test_register_domain() {
        let mut engine = CrossDomainTransfer::<f64>::new(4);
        let outcome = engine.register_domain(make_domain("cv", 4, 8));
        assert!(outcome.is_ok());
        assert_eq!(engine.get_registered_domains(), vec!["cv"]);

        // A 3-dimensional shared vector cannot join a 4-dimensional engine.
        let mismatched = make_domain("bad", 3, 8);
        assert!(engine.register_domain(mismatched).is_err());
    }

    #[test]
    fn test_compute_domain_similarity() {
        let mut engine = CrossDomainTransfer::<f64>::new(4);
        for &(label, msg) in &[("cv", "register cv"), ("nlp", "register nlp")] {
            engine.register_domain(make_domain(label, 4, 8)).expect(msg);
        }

        // cv shared = [1,2,3,4], nlp shared = [-1,-2,-3,-4] => cosine = -1
        let opposite = engine
            .compute_domain_similarity("cv", "nlp")
            .expect("similarity");
        assert!(
            (opposite - (-1.0)).abs() < 1e-10,
            "Expected -1.0 cosine similarity, got {}",
            opposite
        );

        // A domain compared with itself has similarity 1.
        let identical = engine
            .compute_domain_similarity("cv", "cv")
            .expect("self similarity");
        assert!(
            (identical - 1.0).abs() < 1e-10,
            "Expected 1.0, got {}",
            identical
        );

        // Unregistered domains are rejected.
        assert!(engine.compute_domain_similarity("cv", "rl").is_err());
    }

    #[test]
    fn test_transfer_knowledge() {
        let mut engine = CrossDomainTransfer::<f64>::new(4);

        // Two closely aligned domains (shared vectors point the same way).
        let cv = domain_with(
            "cv",
            vec![1.0, 2.0, 3.0, 4.0],
            Array1::from_elem(4, 1.0),
            vec![0.9],
        );
        let cv2 = domain_with(
            "cv2",
            vec![1.1, 2.1, 3.1, 4.1],
            Array1::from_elem(4, 0.5),
            vec![0.7],
        );
        engine.register_domain(cv).expect("register cv");
        engine.register_domain(cv2).expect("register cv2");

        let outcome = engine.transfer("cv", "cv2").expect("transfer");
        assert_eq!(outcome.source_domain, "cv");
        assert_eq!(outcome.target_domain, "cv2");
        assert!(
            outcome.transferability_score > 0.9,
            "Similar domains should have high transferability, got {}",
            outcome.transferability_score
        );
        assert_eq!(outcome.transferred_params.len(), 4);
        assert_eq!(engine.transfer_history().len(), 1);
    }

    #[test]
    fn test_update_shared_representation() {
        let mut engine = CrossDomainTransfer::<f64>::new(4);
        let cv = domain_with("cv", vec![0.0; 4], Array1::zeros(4), Vec::new());
        engine.register_domain(cv).expect("register");

        let grads = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        engine
            .update_shared_representation("cv", &grads, 0.1)
            .expect("update");

        // shared = 0 - 0.1 * [1,2,3,4] = [-0.1, -0.2, -0.3, -0.4]
        let features = &engine.shared_representation().features;
        assert!((features[0] - (-0.1)).abs() < 1e-10);
        assert!((features[3] - (-0.4)).abs() < 1e-10);
        assert_eq!(engine.shared_representation().version, 1);

        // Unregistered domains are rejected.
        let unknown = engine.update_shared_representation("rl", &grads, 0.1);
        assert!(unknown.is_err());

        // Gradients of the wrong length are rejected.
        let short_grads = Array1::from_vec(vec![1.0, 2.0]);
        let mismatched = engine.update_shared_representation("cv", &short_grads, 0.1);
        assert!(mismatched.is_err());
    }

    #[test]
    fn test_transferability_matrix() {
        let mut engine = CrossDomainTransfer::<f64>::new(3);

        // Fewer than two domains cannot form a matrix.
        assert!(engine.get_transferability_matrix().is_err());

        let specs = vec![
            ("a", vec![1.0, 0.0, 0.0], "register a"),
            ("b", vec![0.0, 1.0, 0.0], "register b"),
            ("c", vec![1.0, 1.0, 0.0], "register c"),
        ];
        for (name, shared, msg) in specs {
            engine
                .register_domain(domain_with(name, shared, Array1::zeros(2), Vec::new()))
                .expect(msg);
        }

        let matrix = engine.get_transferability_matrix().expect("matrix");
        assert_eq!(matrix.shape(), &[3, 3]);

        // Every diagonal entry is 1.
        for k in 0..3 {
            let diag = matrix[[k, k]];
            assert!(
                (diag - 1.0).abs() < 1e-10,
                "Diagonal [{},{}] should be 1.0, got {}",
                k,
                k,
                diag
            );
        }

        // a and b are orthogonal => similarity 0.
        let ab = matrix[[0, 1]];
        assert!(
            ab.abs() < 1e-10,
            "Orthogonal domains should have 0 similarity, got {}",
            ab
        );

        // The matrix must equal its transpose.
        for row in 0..3 {
            for col in 0..3 {
                assert!(
                    (matrix[[row, col]] - matrix[[col, row]]).abs() < 1e-10,
                    "Matrix should be symmetric at [{},{}]",
                    row,
                    col
                );
            }
        }
    }
}