use crate::codellama::config::CodeLlamaConfig;
use crate::codellama::model::CodeLlamaForCausalLM;
use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;
/// Result of a single CodeLlama forward pass.
///
/// Every task wrapper in this module wraps the tensor returned by
/// `CodeLlamaForCausalLM::forward` in this type.
pub struct CodeLMOutput {
    /// Logits tensor produced by the underlying model.
    ///
    /// NOTE(review): shape/layout is whatever `CodeLlamaForCausalLM::forward`
    /// returns — presumably (seq_len, vocab_size) or batched; confirm there.
    pub logits: Tensor,
}
/// Code-completion wrapper that feeds token ids directly to a [`CodeLlamaForCausalLM`].
pub struct CodeLlamaCompletion {
    /// Causal language model that performs the actual forward pass.
    inner: CodeLlamaForCausalLM,
}

impl CodeLlamaCompletion {
    /// Builds the completion wrapper, moving `config` into the underlying model.
    ///
    /// # Errors
    /// Propagates any error returned by `CodeLlamaForCausalLM::new`.
    pub fn new(config: CodeLlamaConfig) -> Result<Self> {
        Ok(Self {
            inner: CodeLlamaForCausalLM::new(config)?,
        })
    }

    /// Borrows the configuration stored in the underlying model.
    pub fn config(&self) -> &CodeLlamaConfig {
        self.inner.config()
    }

    /// Runs one forward pass over `input_ids` and returns the resulting logits.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying model's `forward`.
    pub fn complete(&self, input_ids: Vec<u32>) -> Result<CodeLMOutput> {
        Ok(CodeLMOutput {
            logits: self.inner.forward(input_ids)?,
        })
    }
}
/// Fill-in-the-middle wrapper around a [`CodeLlamaForCausalLM`].
pub struct CodeLlamaInfilling {
    /// Causal language model that performs the actual forward pass.
    inner: CodeLlamaForCausalLM,
    /// Copy of `CodeLlamaConfig::infilling`, captured at construction time.
    pub infilling_enabled: bool,
}

impl CodeLlamaInfilling {
    /// Builds the infilling wrapper from `config`.
    ///
    /// The `infilling` flag is copied out of `config` before the config is
    /// moved into the underlying model. Struct-expression fields are evaluated
    /// in the order written, so the read happens first.
    ///
    /// # Errors
    /// Propagates any error returned by `CodeLlamaForCausalLM::new`.
    pub fn new(config: CodeLlamaConfig) -> Result<Self> {
        Ok(Self {
            infilling_enabled: config.infilling,
            inner: CodeLlamaForCausalLM::new(config)?,
        })
    }

    /// Borrows the configuration stored in the underlying model.
    pub fn config(&self) -> &CodeLlamaConfig {
        self.inner.config()
    }

    /// Runs one forward pass over the pre-merged infilling sequence.
    ///
    /// Only `merged_ids` reaches the model; the prefix and suffix slices are
    /// accepted for API symmetry but are not read.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying model's `forward`.
    pub fn infill(
        &self,
        _prefix_ids: &[u32],
        _suffix_ids: &[u32],
        merged_ids: Vec<u32>,
    ) -> Result<CodeLMOutput> {
        // NOTE(review): `infilling_enabled` is not checked here, and
        // `merged_ids` is trusted as-is. Presumably the caller has already
        // assembled prefix/suffix with FIM sentinel tokens — confirm.
        Ok(CodeLMOutput {
            logits: self.inner.forward(merged_ids)?,
        })
    }
}
/// Repository-level completion wrapper that bounds its input to the model's
/// effective context window.
pub struct CodeLlamaRepoLevel {
    /// Causal language model that performs the actual forward pass.
    inner: CodeLlamaForCausalLM,
    /// Maximum number of tokens forwarded to the model.
    ///
    /// Taken from `CodeLlamaConfig::effective_max_context()` at construction.
    /// A value of 0 disables truncation.
    pub repo_context_limit: usize,
}

impl CodeLlamaRepoLevel {
    /// Builds the repository-level wrapper from `config`.
    ///
    /// The context limit is read from `config.effective_max_context()` before
    /// the config is moved into the underlying model.
    ///
    /// # Errors
    /// Propagates any error returned by `CodeLlamaForCausalLM::new`.
    pub fn new(config: CodeLlamaConfig) -> Result<Self> {
        let repo_context_limit = config.effective_max_context();
        let inner = CodeLlamaForCausalLM::new(config)?;
        Ok(Self {
            inner,
            repo_context_limit,
        })
    }

    /// Borrows the configuration stored in the underlying model.
    pub fn config(&self) -> &CodeLlamaConfig {
        self.inner.config()
    }

    /// Runs one forward pass over `input_ids`, capped at `repo_context_limit` tokens.
    ///
    /// When the input is longer than the limit, only the trailing
    /// `repo_context_limit` tokens are kept. This is the most recent context,
    /// which is the part nearest the completion point in a left-to-right causal
    /// model. Previously the limit was computed but never applied.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying model's `forward`.
    pub fn forward_with_repo_context(&self, input_ids: Vec<u32>) -> Result<CodeLMOutput> {
        let input_ids = Self::keep_trailing(input_ids, self.repo_context_limit);
        let logits = self.inner.forward(input_ids)?;
        Ok(CodeLMOutput { logits })
    }

    /// Drops leading tokens so that at most `limit` remain.
    ///
    /// Truncation happens in place without reallocating. `limit == 0` is treated
    /// as "no limit", so the input is never emptied.
    fn keep_trailing(mut ids: Vec<u32>, limit: usize) -> Vec<u32> {
        if limit > 0 && ids.len() > limit {
            let excess = ids.len() - limit;
            ids.drain(..excess);
        }
        ids
    }
}