use super::QueryIntentType;
/// Relative weights for the three layers `l0`, `l1` and `l2`.
///
/// The fields are plain `f32` values: the type itself does not enforce that
/// they are non-negative or that they sum to one. Call
/// [`LayerWeights::normalize`] to rescale them.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerWeights {
    /// Weight of the `l0` layer.
    pub l0: f32,
    /// Weight of the `l1` layer.
    pub l1: f32,
    /// Weight of the `l2` layer.
    pub l2: f32,
}

impl Default for LayerWeights {
    /// Returns the baseline split `l0 = 0.2`, `l1 = 0.3`, `l2 = 0.5`.
    fn default() -> Self {
        Self {
            l0: 0.2,
            l1: 0.3,
            l2: 0.5,
        }
    }
}

impl LayerWeights {
    /// Rescales the weights so that they sum to one.
    ///
    /// Each field is divided by the sum of all three fields. When that sum is
    /// zero, negative, NaN or infinite there is no meaningful scale factor, so
    /// [`LayerWeights::default`] is returned instead.
    ///
    /// Individual negative weights are not clamped; only the sum is checked.
    #[must_use]
    pub fn normalize(self) -> Self {
        let total = self.l0 + self.l1 + self.l2;
        // `NaN <= 0.0` is false and `inf / inf` is NaN, so a plain sign check
        // would let non-finite sums through and produce NaN weights.
        if !total.is_finite() || total <= 0.0 {
            return Self::default();
        }
        Self {
            l0: self.l0 / total,
            l1: self.l1 / total,
            l2: self.l2 / total,
        }
    }
}
/// Picks the preset [`LayerWeights`] associated with a query intent.
///
/// Every listed intent maps to a fixed `(l0, l1, l2)` triple whose components
/// add up to 1.0 (up to floating-point rounding). `QueryIntentType::General`
/// has no dedicated preset and uses `LayerWeights::default()`.
pub fn weights_for_intent(intent_type: &QueryIntentType) -> LayerWeights {
    // Resolve the intent to a bare triple first, then build the struct once.
    let (l0, l1, l2) = match intent_type {
        QueryIntentType::EntityLookup => (0.1, 0.2, 0.7),
        QueryIntentType::Factual => (0.15, 0.25, 0.6),
        QueryIntentType::Temporal => (0.2, 0.35, 0.45),
        QueryIntentType::Relational => (0.2, 0.5, 0.3),
        QueryIntentType::Search => (0.35, 0.35, 0.3),
        // Defer to the `Default` impl so the baseline stays defined in one place.
        QueryIntentType::General => return LayerWeights::default(),
    };
    LayerWeights { l0, l1, l2 }
}