use half::f16;
use std::fmt::{Debug, Formatter};
use std::ops::{Deref, DerefMut};
use ndarray::{s, Array, Array2, Axis, Ix2, Ix3, IxDyn};
use num_traits::FloatConst;
use ort::tensor::ArrayExtensions;
use ort::value::DynValue;
use rand::distributions::WeightedIndex;
use rand::{thread_rng, Rng};
/// Newtype around a two-dimensional `f32` array of logits.
///
/// The methods defined on this type walk axis 0 row by row (one entry per
/// batch item) and apply softmax along axis 1.
pub struct Logits(Array2<f32>);
/// Consuming conversion from an ONNX Runtime value.
///
/// This simply borrows the value and defers to the `TryFrom<&DynValue>`
/// implementation, so both conversions share the same errors and panics.
impl TryFrom<DynValue> for Logits {
    type Error = ort::Error;

    fn try_from(value: DynValue) -> Result<Self, Self::Error> {
        <Self as TryFrom<&DynValue>>::try_from(&value)
    }
}
impl TryFrom<&DynValue> for Logits {
type Error = ort::Error;
fn try_from(value: &DynValue) -> Result<Self, Self::Error> {
let arr = value.try_extract_tensor::<f32>()?.into_owned();
let arr = arr.into_dimensionality::<Ix2>().expect("Expected dim 2");
Ok(Self(arr))
}
}
/// Wraps a dynamically-dimensioned `f32` array as [`Logits`].
///
/// # Panics
///
/// Panics with "Expected dim 2" when the array is not two-dimensional.
impl From<Array<f32, IxDyn>> for Logits {
    fn from(value: Array<f32, IxDyn>) -> Self {
        Self(value.into_dimensionality::<Ix2>().expect("Expected dim 2"))
    }
}
impl Deref for Logits {
type Target = Array2<f32>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Logits {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
/// Formats [`Logits`] exactly like the wrapped array.
impl Debug for Logits {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Delegate with the caller's formatter instead of `write!(f, "{:?}", ..)`
        // so that flags such as `{:#?}` or a precision reach the inner array
        // rather than being silently discarded.
        Debug::fmt(&self.0, f)
    }
}
impl Logits {
    /// Builds [`Logits`] from an ONNX Runtime value holding a 3-D tensor.
    ///
    /// The tensor is extracted as `f32`; if that fails it is extracted as
    /// `f16` and every element is widened to `f32`. Axis 1 is then removed,
    /// leaving a 2-D array.
    ///
    /// # Errors
    ///
    /// Propagates the `f16` extraction error when the `f32` extraction has
    /// already failed and the `f16` one fails too.
    ///
    /// # Panics
    ///
    /// Panics with "Expected 3 dimensions" when the tensor is not 3-D.
    pub fn from_3d_dyn_value(value: &DynValue) -> ort::Result<Self> {
        let dynamic = match value.try_extract_tensor::<f32>() {
            Ok(view) => view.into_owned(),
            Err(_) => {
                let half_view = value.try_extract_tensor::<f16>()?;
                half_view.mapv(f32::from)
            }
        };
        let cube = dynamic
            .into_dimensionality::<Ix3>()
            .expect("Expected 3 dimensions");
        // NOTE(review): presumably axis 1 always has length 1 here — confirm
        // against the callers.
        Ok(Self(cube.remove_axis(Axis(1))))
    }

    /// Combines the two halves of the batch with a guidance scale.
    ///
    /// The first half of the rows is treated as the conditional logits and the
    /// second half as the unconditional ones. The result has half as many rows
    /// and equals `(cond - uncond) * guidance_scale + uncond`.
    ///
    /// # Panics
    ///
    /// Panics when the number of rows is odd.
    pub fn apply_free_guidance(self, guidance_scale: usize) -> Self {
        let rows = self.0.dim().0;
        assert!(
            rows % 2 == 0,
            "In order to apply free guidance to the logits, the first size of the first dimension must be even"
        );
        let half = rows / 2;
        let conditional = self.0.slice(s![..half, ..]);
        let unconditional = self.0.slice(s![half.., ..]);
        let scale = guidance_scale as f32;
        let delta = conditional.into_owned() - unconditional;
        Self(delta * scale + unconditional)
    }

    /// Draws one index per row using top-`k` sampling.
    ///
    /// Each row is turned into probabilities with a softmax along axis 1. The
    /// `k` most probable entries (or all of them when the row is shorter than
    /// `k`) are kept, and one of them is drawn at random, weighted by its
    /// probability.
    ///
    /// Returns, for every row, the drawn column index together with the
    /// natural logarithm of its softmax probability.
    ///
    /// # Panics
    ///
    /// Panics when two probabilities cannot be compared (e.g. NaN) or when the
    /// weighted distribution cannot be built, for instance when `k` is 0 and
    /// no candidate remains.
    pub fn sample(&self, k: usize) -> Vec<(i64, f32)> {
        let probabilities = self.0.softmax(Axis(1));
        let mut rng = thread_rng();
        let mut picks = Vec::with_capacity(probabilities.len_of(Axis(0)));
        for row in probabilities.axis_iter(Axis(0)) {
            let keep = k.min(row.len());
            let mut candidates: Vec<(i64, f32)> = row
                .iter()
                .enumerate()
                .map(|(column, &probability)| (column as i64, probability))
                .collect();
            // Most probable first, so the head of the vector is the top-k set.
            candidates.sort_by(|lhs, rhs| {
                rhs.1
                    .partial_cmp(&lhs.1)
                    .expect("Could not compare two numbers in order to sort them")
            });
            candidates.truncate(keep);
            let weights = WeightedIndex::new(candidates.iter().map(|&(_, probability)| probability))
                .expect("Could not create WeightedIndex distribution");
            let (column, probability) = candidates[rng.sample(weights)];
            picks.push((column, probability.log(f32::E())));
        }
        picks
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guidance halves the batch and computes `uncond + scale * (cond - uncond)`.
    #[test]
    fn free_guidance() {
        let logits = Logits::from(Array::from(vec![[10., -1., 3.], [-1., 1., 11.]]).into_dyn());
        let logits = logits.apply_free_guidance(3);
        assert_eq!(logits.shape(), &[1, 3]);
        // (10 - -1) * 3 + -1 = 32, (-1 - 1) * 3 + 1 = -5, (3 - 11) * 3 + 11 = -13;
        // all of these are exactly representable in f32.
        let expected: Array2<f32> = Array::from(vec![[32., -5., -13.]]);
        assert_eq!(logits.0, expected);
    }

    /// An odd number of rows cannot be split into two equal halves.
    #[test]
    #[should_panic(expected = "must be even")]
    fn free_guidance_rejects_odd_batch() {
        let logits = Logits::from(Array2::<f32>::zeros((3, 2)).into_dyn());
        let _ = logits.apply_free_guidance(3);
    }

    /// With `k == 1` only the most probable entry survives, so sampling is greedy.
    #[test]
    fn sample_with_k_one_is_greedy() {
        let rows: Array2<f32> = Array::from(vec![[0., 5., 1.], [7., 0., 0.]]);
        let logits = Logits::from(rows.into_dyn());
        let picks = logits.sample(1);
        assert_eq!(picks.len(), 2);
        assert_eq!(picks[0].0, 1);
        assert_eq!(picks[1].0, 0);
        for (_, log_prob) in picks {
            assert!(log_prob.is_finite());
            assert!(log_prob <= f32::EPSILON);
        }
    }

    /// A `k` larger than the row length is clamped instead of panicking.
    #[test]
    fn sample_clamps_k_to_row_length() {
        let rows: Array2<f32> = Array::from(vec![[0., 1., 2.]]);
        let logits = Logits::from(rows.into_dyn());
        let picks = logits.sample(10);
        assert_eq!(picks.len(), 1);
        assert!((0..3).contains(&picks[0].0));
    }
}