use alloc::{vec, vec::Vec};
use miden_core::{EventName, Felt, FieldElement, LexicographicWord, Word};
use miden_processor::{AdviceMutation, EventError, MemoryError, ProcessState};
/// Event name identifying the sorted-array lower-bound lookup.
///
/// NOTE(review): presumably the name under which `handle_lowerbound_array` is registered;
/// the registration itself is not visible in this file — confirm at the call site.
pub const LOWERBOUND_ARRAY_EVENT_NAME: EventName =
EventName::new("stdlib::collections::sorted_array::lowerbound_array");
/// Event name identifying the sorted key-value array lower-bound lookup.
///
/// NOTE(review): presumably the name under which `handle_lowerbound_key_value` is registered;
/// the registration itself is not visible in this file — confirm at the call site.
pub const LOWERBOUND_KEY_VALUE_EVENT_NAME: EventName =
EventName::new("stdlib::collections::sorted_array::lowerbound_key_value");
/// Selects how much of each stored word takes part in the key comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum KeySize {
/// The stored word is compared as-is.
Full,
/// Elements 0 and 1 of the stored word are zeroed before the comparison, so only the
/// remaining elements can distinguish two keys.
Half,
}
/// Event handler for the lower-bound lookup in a sorted array of words.
///
/// Runs the shared lower-bound search over consecutive words (one entry every 4 addresses),
/// comparing each stored word in full against the key.
///
/// # Errors
/// Propagates any error produced by the shared search routine `push_lowerbound_result`.
pub fn handle_lowerbound_array(process: &ProcessState) -> Result<Vec<AdviceMutation>, EventError> {
    // Every entry of a plain sorted array is a single word, i.e. 4 addresses apart.
    const ARRAY_STRIDE: u32 = 4;

    push_lowerbound_result(process, ARRAY_STRIDE, KeySize::Full)
}
/// Event handler for the lower-bound lookup in a sorted array of key-value pairs.
///
/// Reads the `use_full_key` flag from operand-stack position 7 and runs the shared lower-bound
/// search with one entry every 8 addresses:
/// - `1` compares stored keys in full,
/// - `0` compares only the half of each stored key that `KeySize::Half` keeps.
///
/// # Errors
/// Returns an error if the flag is neither 0 nor 1, and propagates any error produced by the
/// shared search routine `push_lowerbound_result`.
pub fn handle_lowerbound_key_value(
    process: &ProcessState,
) -> Result<Vec<AdviceMutation>, EventError> {
    // Operand-stack position of the flag that selects full vs. half key comparison.
    const USE_FULL_KEY_OFFSET: usize = 7;
    // Distance in addresses between two consecutive keys of the key-value array.
    const KEY_VALUE_STRIDE: u32 = 8;

    let use_full_key = process.get_stack_item(USE_FULL_KEY_OFFSET);
    let key_size = match use_full_key.as_int() {
        1 => KeySize::Full,
        0 => KeySize::Half,
        _ => {
            let message = alloc::format!("use_full_key must be 0 or 1, was {use_full_key}");
            return Err(EventError::from(message));
        },
    };

    push_lowerbound_result(process, KEY_VALUE_STRIDE, key_size)
}
/// Operand-stack position from which the search key word is read.
const KEY_OFFSET: usize = 1;
/// Operand-stack position holding the start address of the array in memory.
const START_ADDR_OFFSET: usize = 5;
/// Operand-stack position holding the end address of the array in memory.
const END_ADDR_OFFSET: usize = 6;
/// Finds the lower bound of a key in a sorted in-memory array and returns it as an advice-stack
/// extension.
///
/// The key word is read from the operand stack at `KEY_OFFSET`, and the array bounds from the
/// stack items at `START_ADDR_OFFSET` and `END_ADDR_OFFSET`. Entries are read every `stride`
/// addresses and converted to search keys according to `key_size`; memory words that have no
/// stored value are treated as the default word.
///
/// On success the returned mutation extends the advice stack with two values, passed in this
/// order: `was_key_found` (1 if the entry at the reported address equals the key, 0 otherwise)
/// and the address of the first entry that is `>=` the key. If there is no such entry (including
/// the empty range), the end address is reported and `was_key_found` is 0.
///
/// The whole range is always scanned, because the scan also verifies that the entries are in
/// non-decreasing order.
///
/// # Errors
/// - the address range cannot be read from the stack,
/// - the start address is not a multiple of 4,
/// - the range length is not a multiple of `stride`,
/// - a memory read fails,
/// - an entry is smaller than its predecessor.
///
/// # Panics
/// Panics if `stride` is neither 4 nor 8.
fn push_lowerbound_result(
    process: &ProcessState,
    stride: u32,
    key_size: KeySize,
) -> Result<Vec<AdviceMutation>, EventError> {
    assert!(stride == 4 || stride == 8);

    let key = LexicographicWord::new(process.get_stack_word_be(KEY_OFFSET));
    let addr_range = process.get_mem_addr_range(START_ADDR_OFFSET, END_ADDR_OFFSET)?;
    let (start_addr, end_addr) = (addr_range.start, addr_range.end);

    // The array must begin on a word boundary.
    if start_addr % 4 != 0 {
        let unaligned = MemoryError::unaligned_word_access(
            start_addr,
            process.ctx(),
            Felt::from(process.clk()),
            &(),
        );
        return Err(unaligned.into());
    }

    // The array must consist of a whole number of entries.
    if (end_addr - start_addr) % stride != 0 {
        let size = addr_range.len() as u32;
        let invalid_range = match stride {
            4 => SortedArrayError::InvalidArrayRange { size },
            _ => SortedArrayError::InvalidKeyValueRange { size },
        };
        return Err(invalid_range.into());
    }

    // Reads the entry at `addr` and reduces it to the part that takes part in the comparison.
    let load_key = |addr: u32| {
        process
            .get_mem_word(process.ctx(), addr)
            .map(|stored| word_to_search_key(stored.unwrap_or_default(), key_size))
    };

    // `(address, is_exact_match)` of the first entry that is not smaller than the key.
    let mut lower_bound = None;
    let mut predecessor = None;

    // An empty range performs no iterations and falls through to the "not found" result below.
    for addr in addr_range.step_by(stride as usize) {
        let current = load_key(addr)?;

        if let Some(previous) = predecessor.take() {
            if current < previous {
                return Err(SortedArrayError::NotAscendingOrder {
                    index: addr,
                    value: current.into(),
                    predecessor: previous.into(),
                }
                .into());
            }
        }

        // Only the first qualifying entry is recorded; later ones merely get order-checked.
        if lower_bound.is_none() && current >= key {
            lower_bound = Some((addr, current == key));
        }

        predecessor = Some(current);
    }

    let (was_key_found, result_addr) = match lower_bound {
        Some((addr, is_exact_match)) => (is_exact_match, addr),
        None => (false, end_addr),
    };

    Ok(vec![AdviceMutation::extend_stack(vec![
        Felt::from(was_key_found),
        Felt::from(result_addr),
    ])])
}
/// Converts a stored word into the value that is compared against the search key.
///
/// For [`KeySize::Full`] the word is used unchanged. For [`KeySize::Half`] elements 0 and 1 are
/// replaced by zero first, so they cannot influence the ordering.
fn word_to_search_key(mut word: Word, key_size: KeySize) -> LexicographicWord {
    match key_size {
        KeySize::Full => {},
        KeySize::Half => {
            word[0] = Felt::ZERO;
            word[1] = Felt::ZERO;
        },
    }

    LexicographicWord::new(word)
}
#[derive(Debug, thiserror::Error)]
pub enum SortedArrayError {
#[error("element at index {index} ({value}) is smaller than the predecessor ({predecessor})")]
NotAscendingOrder {
index: u32,
value: Word,
predecessor: Word,
},
#[error("array size must be divisible by 4, but was {size}")]
InvalidArrayRange { size: u32 },
#[error("key-value array must have size divisible by 4 or 8, but was {size}")]
InvalidKeyValueRange { size: u32 },
}