use crate::phi2::config::Phi2Config;
use crate::phi2::model::Phi2ForCausalLM;
use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;
/// Result of a forward pass through [`Phi2ForCodeGeneration`].
pub struct Phi2CausalLMOutput {
    /// Tensor returned by the wrapped causal language model's forward pass.
    pub logits: Tensor,
}
/// Code-generation front end that delegates to a wrapped [`Phi2ForCausalLM`].
pub struct Phi2ForCodeGeneration {
    /// The underlying causal language model every method forwards to.
    inner: Phi2ForCausalLM,
}
impl Phi2ForCodeGeneration {
    /// Creates the wrapper by constructing the underlying [`Phi2ForCausalLM`] from `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Phi2ForCausalLM::new`.
    pub fn new(config: Phi2Config) -> Result<Self> {
        let inner = Phi2ForCausalLM::new(config)?;
        Ok(Self { inner })
    }

    /// Returns the configuration held by the wrapped model.
    pub fn config(&self) -> &Phi2Config {
        self.inner.config()
    }

    /// Returns the parameter count reported by the wrapped model.
    pub fn parameter_count(&self) -> usize {
        self.inner.parameter_count()
    }

    /// Runs the wrapped model on `input_ids` and packages the resulting tensor as
    /// [`Phi2CausalLMOutput::logits`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped model's `forward`.
    pub fn forward(&self, input_ids: Vec<u32>) -> Result<Phi2CausalLMOutput> {
        let logits = self.inner.forward(input_ids)?;
        Ok(Phi2CausalLMOutput { logits })
    }

    /// Returns the position of the largest value in `logits` (greedy arg-max).
    ///
    /// The scan covers every element of the tensor in its iteration order, so the
    /// returned index is a flat element index.
    /// NOTE(review): if `logits` holds more than one sequence position this is not
    /// directly a vocabulary id — confirm the expected shape with callers.
    ///
    /// Behaviour details:
    /// * NaN entries are ignored. Previously a NaN compared as "equal" and replaced
    ///   the running maximum, so any larger value *before* the NaN was forgotten.
    /// * When several entries share the maximum, the last one is returned.
    /// * An empty tensor, or one containing only NaN, yields `0`.
    /// * Any tensor variant other than `Tensor::F32` yields `0`.
    ///   NOTE(review): this fallback is silent; consider surfacing an error for
    ///   unsupported dtypes.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `Result` is kept for interface stability.
    pub fn greedy_next_token(&self, logits: &Tensor) -> Result<u32> {
        match logits {
            Tensor::F32(arr) => {
                // Scan the tensor in place; copying it into a temporary Vec first
                // only doubled memory traffic.
                let best = arr
                    .iter()
                    .enumerate()
                    // Drop NaN up front so every remaining comparison is well defined.
                    .filter(|(_, value)| !value.is_nan())
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    // NOTE(review): `as u32` truncates for tensors with more than
                    // u32::MAX elements — presumably unreachable for logits; confirm.
                    .map(|(idx, _)| idx as u32)
                    .unwrap_or(0);
                Ok(best)
            }
            _ => Ok(0),
        }
    }

    /// Runs a single forward pass over `prompt_ids`, picks the greedy next token and
    /// returns a placeholder string embedding that token id.
    ///
    /// No token is decoded into text and no generation loop is run; the returned
    /// string is a fixed comment-style template containing the chosen token id.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::forward`] and [`Self::greedy_next_token`].
    pub fn generate_code(&self, prompt_ids: Vec<u32>) -> Result<String> {
        let output = self.forward(prompt_ids)?;
        let next_token = self.greedy_next_token(&output.logits)?;
        Ok(format!("# generated code (next_token={next_token})"))
    }
}