1use crate::error::{OptimError, Result};
11use scirs2_core::ndarray::{Array1, Array2, ScalarOperand, Zip};
12use scirs2_core::numeric::Float;
13use std::collections::HashMap;
14use std::fmt::Debug;
15
16#[derive(Debug, Clone)]
25pub struct DomainKnowledge<T: Float + Debug + Send + Sync + 'static> {
26 pub domain_name: String,
28 pub shared_representation: Array1<T>,
30 pub domain_specific_params: Array1<T>,
32 pub performance_history: Vec<T>,
34}
35
36#[derive(Debug, Clone)]
42pub struct SharedRepresentation<T: Float + Debug + Send + Sync + 'static> {
43 pub features: Array1<T>,
45 pub dimension: usize,
47 pub version: usize,
49}
50
51impl<T: Float + Debug + Send + Sync + 'static> SharedRepresentation<T> {
52 pub fn new(dimension: usize) -> Self {
54 Self {
55 features: Array1::<T>::zeros(dimension),
56 dimension,
57 version: 0,
58 }
59 }
60}
61
62#[derive(Debug, Clone)]
68pub struct TransferResult<T: Float + Debug + Send + Sync + 'static> {
69 pub transferred_params: Array1<T>,
71 pub transferability_score: T,
73 pub source_domain: String,
75 pub target_domain: String,
77}
78
79#[derive(Debug)]
89pub struct CrossDomainTransfer<T: Float + Debug + Send + Sync + 'static> {
90 domains: HashMap<String, DomainKnowledge<T>>,
92 shared_repr: SharedRepresentation<T>,
94 transfer_history: Vec<TransferResult<T>>,
96}
97
98impl<T: Float + Debug + Send + Sync + 'static + ScalarOperand> CrossDomainTransfer<T> {
99 pub fn new(shared_dim: usize) -> Self {
101 Self {
102 domains: HashMap::new(),
103 shared_repr: SharedRepresentation::new(shared_dim),
104 transfer_history: Vec::new(),
105 }
106 }
107
108 pub fn transfer_history(&self) -> &[TransferResult<T>] {
110 &self.transfer_history
111 }
112
113 pub fn shared_representation(&self) -> &SharedRepresentation<T> {
115 &self.shared_repr
116 }
117
118 pub fn register_domain(&mut self, knowledge: DomainKnowledge<T>) -> Result<()> {
128 if knowledge.shared_representation.len() != self.shared_repr.dimension {
129 return Err(OptimError::InvalidConfig(format!(
130 "Shared representation dimension mismatch: expected {}, got {}",
131 self.shared_repr.dimension,
132 knowledge.shared_representation.len()
133 )));
134 }
135 self.domains
136 .insert(knowledge.domain_name.clone(), knowledge);
137 Ok(())
138 }
139
140 pub fn get_registered_domains(&self) -> Vec<&str> {
142 let mut names: Vec<&str> = self.domains.keys().map(|s| s.as_str()).collect();
143 names.sort();
144 names
145 }
146
147 pub fn compute_domain_similarity(&self, source: &str, target: &str) -> Result<T> {
159 let src = self.get_domain(source)?;
160 let tgt = self.get_domain(target)?;
161
162 let dot = dot_product(&src.shared_representation, &tgt.shared_representation);
163 let norm_src = l2_norm(&src.shared_representation);
164 let norm_tgt = l2_norm(&tgt.shared_representation);
165
166 let denom = norm_src * norm_tgt;
167 if denom <= T::zero() {
168 return Ok(T::zero());
169 }
170 Ok(dot / denom)
171 }
172
173 pub fn transfer(&mut self, source: &str, target: &str) -> Result<TransferResult<T>> {
192 let src = self.get_domain(source)?.clone();
194 let tgt = self.get_domain(target)?.clone();
195
196 if src.domain_specific_params.len() != tgt.domain_specific_params.len() {
197 return Err(OptimError::ComputationError(format!(
198 "Domain-specific parameter dimension mismatch: source {} vs target {}",
199 src.domain_specific_params.len(),
200 tgt.domain_specific_params.len()
201 )));
202 }
203
204 let similarity = self.compute_domain_similarity(source, target)?;
205 let two = T::from(2.0).unwrap_or_else(|| T::one() + T::one());
207 let transferability = (similarity + T::one()) / two;
208
209 let dim = tgt.domain_specific_params.len();
211 let mut transferred = Array1::<T>::zeros(dim);
212
213 let shared_dim = src
217 .shared_representation
218 .len()
219 .min(tgt.shared_representation.len());
220 let mut shared_diff = Array1::<T>::zeros(dim);
221 for i in 0..shared_dim.min(dim) {
222 shared_diff[i] = src.shared_representation[i] - tgt.shared_representation[i];
223 }
224
225 Zip::from(&mut transferred)
226 .and(&tgt.domain_specific_params)
227 .and(&shared_diff)
228 .for_each(|out, &tgt_p, &sd| {
229 *out = tgt_p + transferability * sd;
230 });
231
232 let result = TransferResult {
233 transferred_params: transferred,
234 transferability_score: transferability,
235 source_domain: source.to_string(),
236 target_domain: target.to_string(),
237 };
238
239 self.transfer_history.push(result.clone());
240 Ok(result)
241 }
242
243 pub fn update_shared_representation(
257 &mut self,
258 domain_name: &str,
259 gradients: &Array1<T>,
260 lr: T,
261 ) -> Result<()> {
262 if gradients.len() != self.shared_repr.dimension {
263 return Err(OptimError::ComputationError(format!(
264 "Gradient dimension {} does not match shared dimension {}",
265 gradients.len(),
266 self.shared_repr.dimension
267 )));
268 }
269
270 if !self.domains.contains_key(domain_name) {
272 return Err(OptimError::InvalidState(format!(
273 "Domain '{}' is not registered",
274 domain_name
275 )));
276 }
277
278 Zip::from(&mut self.shared_repr.features)
280 .and(gradients)
281 .for_each(|f, &g| {
282 *f = *f - lr * g;
283 });
284 self.shared_repr.version += 1;
285
286 if let Some(domain) = self.domains.get_mut(domain_name) {
288 Zip::from(&mut domain.shared_representation)
289 .and(&self.shared_repr.features)
290 .for_each(|d, &s| {
291 *d = s;
292 });
293 }
294
295 Ok(())
296 }
297
298 pub fn get_transferability_matrix(&self) -> Result<Array2<T>> {
311 let names = self.get_registered_domains();
312 let n = names.len();
313 if n < 2 {
314 return Err(OptimError::InsufficientData(
315 "Need at least 2 registered domains to build a transferability matrix".into(),
316 ));
317 }
318
319 let mut matrix = Array2::<T>::zeros((n, n));
320 for (i, &src) in names.iter().enumerate() {
321 for (j, &tgt) in names.iter().enumerate() {
322 if i == j {
323 matrix[[i, j]] = T::one();
324 } else {
325 matrix[[i, j]] = self.compute_domain_similarity(src, tgt)?;
326 }
327 }
328 }
329 Ok(matrix)
330 }
331
332 fn get_domain(&self, name: &str) -> Result<&DomainKnowledge<T>> {
338 self.domains
339 .get(name)
340 .ok_or_else(|| OptimError::InvalidState(format!("Domain '{}' is not registered", name)))
341 }
342}
343
344fn dot_product<T: Float + Debug + Send + Sync + 'static>(a: &Array1<T>, b: &Array1<T>) -> T {
350 a.iter()
351 .zip(b.iter())
352 .fold(T::zero(), |acc, (&x, &y)| acc + x * y)
353}
354
355fn l2_norm<T: Float + Debug + Send + Sync + 'static>(arr: &Array1<T>) -> T {
357 arr.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt()
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Builds a test domain whose shared representation is
    /// `[1, 2, ..., shared_dim]` (negated when `name` contains "nlp"), whose
    /// domain-specific parameters are all `0.5`, and whose performance
    /// history is fixed.
    fn make_domain(name: &str, shared_dim: usize, specific_dim: usize) -> DomainKnowledge<f64> {
        let sign = if name.contains("nlp") { -1.0 } else { 1.0 };
        let shared_representation: Array1<f64> =
            (0..shared_dim).map(|i| (i as f64 + 1.0) * sign).collect();
        DomainKnowledge {
            domain_name: name.to_string(),
            shared_representation,
            domain_specific_params: Array1::from_elem(specific_dim, 0.5),
            performance_history: vec![0.8, 0.85, 0.9],
        }
    }

    /// Builds a test domain from explicitly supplied parts.
    fn domain_from(
        name: &str,
        shared: Vec<f64>,
        specific: Array1<f64>,
        history: Vec<f64>,
    ) -> DomainKnowledge<f64> {
        DomainKnowledge {
            domain_name: name.to_string(),
            shared_representation: Array1::from_vec(shared),
            domain_specific_params: specific,
            performance_history: history,
        }
    }

    #[test]
    fn test_register_domain() {
        let mut engine = CrossDomainTransfer::<f64>::new(4);

        // A domain matching the shared dimension is accepted and listed.
        assert!(engine.register_domain(make_domain("cv", 4, 8)).is_ok());
        assert_eq!(engine.get_registered_domains(), vec!["cv"]);

        // A shared representation of the wrong length is rejected.
        assert!(engine.register_domain(make_domain("bad", 3, 8)).is_err());
    }

    #[test]
    fn test_compute_domain_similarity() {
        let mut engine = CrossDomainTransfer::<f64>::new(4);
        for (name, msg) in [("cv", "register cv"), ("nlp", "register nlp")] {
            engine.register_domain(make_domain(name, 4, 8)).expect(msg);
        }

        // "nlp" is the exact negation of "cv", so the vectors are opposed.
        let sim = engine
            .compute_domain_similarity("cv", "nlp")
            .expect("similarity");
        assert!(
            (sim + 1.0).abs() < 1e-10,
            "Expected -1.0 cosine similarity, got {}",
            sim
        );

        // A domain is perfectly similar to itself.
        let self_sim = engine
            .compute_domain_similarity("cv", "cv")
            .expect("self similarity");
        assert!(
            (self_sim - 1.0).abs() < 1e-10,
            "Expected 1.0, got {}",
            self_sim
        );

        // An unknown domain is an error.
        assert!(engine.compute_domain_similarity("cv", "rl").is_err());
    }

    #[test]
    fn test_transfer_knowledge() {
        let mut engine = CrossDomainTransfer::<f64>::new(4);

        let cv = domain_from(
            "cv",
            vec![1.0, 2.0, 3.0, 4.0],
            Array1::from_elem(4, 1.0),
            vec![0.9],
        );
        let cv2 = domain_from(
            "cv2",
            vec![1.1, 2.1, 3.1, 4.1],
            Array1::from_elem(4, 0.5),
            vec![0.7],
        );
        engine.register_domain(cv).expect("register cv");
        engine.register_domain(cv2).expect("register cv2");

        let result = engine.transfer("cv", "cv2").expect("transfer");
        assert_eq!(result.source_domain, "cv");
        assert_eq!(result.target_domain, "cv2");
        assert!(
            result.transferability_score > 0.9,
            "Similar domains should have high transferability, got {}",
            result.transferability_score
        );
        assert_eq!(result.transferred_params.len(), 4);
        assert_eq!(engine.transfer_history().len(), 1);
    }

    #[test]
    fn test_update_shared_representation() {
        let mut engine = CrossDomainTransfer::<f64>::new(4);
        engine
            .register_domain(domain_from(
                "cv",
                vec![0.0; 4],
                Array1::zeros(4),
                vec![],
            ))
            .expect("register");

        let grads = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        engine
            .update_shared_representation("cv", &grads, 0.1)
            .expect("update");

        // Starting from zeros, one step gives -lr * gradient.
        let shared = &engine.shared_representation().features;
        assert!((shared[0] - (-0.1)).abs() < 1e-10);
        assert!((shared[3] - (-0.4)).abs() < 1e-10);
        assert_eq!(engine.shared_representation().version, 1);

        // Unknown domain.
        let unknown = engine.update_shared_representation("rl", &grads, 0.1);
        assert!(unknown.is_err());

        // Gradient of the wrong length.
        let bad_grads = Array1::from_vec(vec![1.0, 2.0]);
        let mismatched = engine.update_shared_representation("cv", &bad_grads, 0.1);
        assert!(mismatched.is_err());
    }

    #[test]
    fn test_transferability_matrix() {
        let mut engine = CrossDomainTransfer::<f64>::new(3);

        // With no domains registered the matrix cannot be built.
        assert!(engine.get_transferability_matrix().is_err());

        let fixtures = [
            ("a", vec![1.0, 0.0, 0.0], "register a"),
            ("b", vec![0.0, 1.0, 0.0], "register b"),
            ("c", vec![1.0, 1.0, 0.0], "register c"),
        ];
        for (name, shared, msg) in fixtures {
            engine
                .register_domain(domain_from(name, shared, Array1::zeros(2), vec![]))
                .expect(msg);
        }

        let matrix = engine.get_transferability_matrix().expect("matrix");
        assert_eq!(matrix.shape(), &[3, 3]);

        for i in 0..3 {
            let diagonal = matrix[[i, i]];
            assert!(
                (diagonal - 1.0).abs() < 1e-10,
                "Diagonal [{},{}] should be 1.0, got {}",
                i,
                i,
                diagonal
            );
        }

        // "a" and "b" are orthogonal unit vectors.
        let a_vs_b = matrix[[0, 1]];
        assert!(
            a_vs_b.abs() < 1e-10,
            "Orthogonal domains should have 0 similarity, got {}",
            a_vs_b
        );

        for i in 0..3 {
            for j in 0..3 {
                let gap = (matrix[[i, j]] - matrix[[j, i]]).abs();
                assert!(gap < 1e-10, "Matrix should be symmetric at [{},{}]", i, j);
            }
        }
    }
}