use alloc::vec::Vec;
use core::cell::RefCell;
use serde::{Deserialize, Serialize};
use crate::{
corpus::{Corpus, CorpusId, Testcase},
inputs::{Input, UsesInput},
Error,
};
/// One entry of the hash-map-backed testcase storage.
///
/// Only compiled when the `corpus_btreemap` feature is disabled. Besides the
/// testcase itself, every entry stores the ids of its neighbours, so the
/// entries form a doubly linked list that is maintained by
/// `TestcaseStorage::insert` and `TestcaseStorage::remove`.
#[cfg(not(feature = "corpus_btreemap"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "I: serde::de::DeserializeOwned")]
pub struct TestcaseStorageItem<I>
where
I: Input,
{
/// The stored testcase, wrapped in a `RefCell`.
pub testcase: RefCell<Testcase<I>>,
/// Id of the preceding entry in the linked list, or `None` for the first entry.
pub prev: Option<CorpusId>,
/// Id of the following entry in the linked list, or `None` for the last entry.
pub next: Option<CorpusId>,
}
/// Map type backing `TestcaseStorage` when the `corpus_btreemap` feature is
/// disabled: a hash map from `CorpusId` to a linked `TestcaseStorageItem`.
#[cfg(not(feature = "corpus_btreemap"))]
pub type TestcaseStorageMap<I> =
    hashbrown::hash_map::HashMap<CorpusId, TestcaseStorageItem<I>>;
/// Map type backing `TestcaseStorage` when the `corpus_btreemap` feature is
/// enabled: an ordered map from `CorpusId` straight to the testcase cell.
#[cfg(feature = "corpus_btreemap")]
pub type TestcaseStorageMap<I> = alloc::collections::BTreeMap<CorpusId, RefCell<Testcase<I>>>;
/// Storage for the testcases of a corpus, keyed by `CorpusId`.
///
/// The backing map is selected by the `corpus_btreemap` feature (see
/// `TestcaseStorageMap`). Without that feature, the first and last ids are
/// tracked explicitly so the entries can be walked as a linked list.
#[derive(Default, Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "I: serde::de::DeserializeOwned")]
pub struct TestcaseStorage<I>
where
I: Input,
{
/// The map from id to stored testcase.
pub map: TestcaseStorageMap<I>,
/// The ids of the stored testcases; the storage's own methods keep this
/// vector sorted by inserting and removing through binary search.
pub keys: Vec<CorpusId>,
/// Value the next `CorpusId` is created from; incremented on every insert
/// and never decremented on removal.
progressive_idx: usize,
/// Id of the first entry of the linked list, `None` while the storage is empty.
#[cfg(not(feature = "corpus_btreemap"))]
first_idx: Option<CorpusId>,
/// Id of the last entry of the linked list, `None` while the storage is empty.
#[cfg(not(feature = "corpus_btreemap"))]
last_idx: Option<CorpusId>,
}
/// The storage handles testcases whose input type is `I`.
impl<I: Input> UsesInput for TestcaseStorage<I> {
    type Input = I;
}
impl<I> TestcaseStorage<I>
where
    I: Input,
{
    /// Records `id` in the sorted `keys` vector; does nothing if it is already present.
    fn insert_key(&mut self, id: CorpusId) {
        // `Err(pos)` means `id` is absent and `pos` is where it keeps `keys` sorted.
        if let Err(pos) = self.keys.binary_search(&id) {
            self.keys.insert(pos, id);
        }
    }

    /// Drops `id` from the sorted `keys` vector; does nothing if it is absent.
    fn remove_key(&mut self, id: CorpusId) {
        if let Ok(pos) = self.keys.binary_search(&id) {
            self.keys.remove(pos);
        }
    }

    /// Stores `testcase` under a freshly created id and returns that id.
    ///
    /// The new entry is appended to the end of the linked list.
    ///
    /// # Panics
    ///
    /// Panics if `last_idx` refers to an entry that is no longer in the map,
    /// which would mean the linked-list bookkeeping is corrupted.
    #[cfg(not(feature = "corpus_btreemap"))]
    pub fn insert(&mut self, testcase: RefCell<Testcase<I>>) -> CorpusId {
        let idx = CorpusId::from(self.progressive_idx);
        self.progressive_idx += 1;

        // The current tail (if any) becomes the predecessor of the new entry.
        let prev = self.last_idx;
        if let Some(last_idx) = prev {
            self.map
                .get_mut(&last_idx)
                .expect("`last_idx` must refer to an entry stored in the map")
                .next = Some(idx);
        }
        if self.first_idx.is_none() {
            self.first_idx = Some(idx);
        }
        self.last_idx = Some(idx);

        self.insert_key(idx);
        self.map.insert(
            idx,
            TestcaseStorageItem {
                testcase,
                prev,
                next: None,
            },
        );
        idx
    }

    /// Stores `testcase` under a freshly created id and returns that id.
    #[cfg(feature = "corpus_btreemap")]
    pub fn insert(&mut self, testcase: RefCell<Testcase<I>>) -> CorpusId {
        let idx = CorpusId::from(self.progressive_idx);
        self.progressive_idx += 1;
        self.insert_key(idx);
        self.map.insert(idx, testcase);
        idx
    }

    /// Swaps the testcase stored under `idx` for `testcase`.
    ///
    /// Returns the previously stored testcase, or `None` (dropping `testcase`)
    /// if nothing is stored under `idx`.
    #[cfg(not(feature = "corpus_btreemap"))]
    pub fn replace(&mut self, idx: CorpusId, testcase: Testcase<I>) -> Option<Testcase<I>> {
        self.map
            .get_mut(&idx)
            .map(|entry| entry.testcase.replace(testcase))
    }

    /// Swaps the testcase stored under `idx` for `testcase`.
    ///
    /// Returns the previously stored testcase, or `None` (dropping `testcase`)
    /// if nothing is stored under `idx`.
    #[cfg(feature = "corpus_btreemap")]
    pub fn replace(&mut self, idx: CorpusId, testcase: Testcase<I>) -> Option<Testcase<I>> {
        self.map.get_mut(&idx).map(|entry| entry.replace(testcase))
    }

    /// Removes the entry stored under `idx` and returns its testcase cell,
    /// or `None` if nothing is stored under `idx`.
    ///
    /// The neighbours of the removed entry are linked to each other, and
    /// `first_idx` / `last_idx` are updated when an end of the list is removed.
    ///
    /// # Panics
    ///
    /// Panics if a neighbour id recorded in the removed entry is not in the
    /// map, which would mean the linked-list bookkeeping is corrupted.
    #[cfg(not(feature = "corpus_btreemap"))]
    pub fn remove(&mut self, idx: CorpusId) -> Option<RefCell<Testcase<I>>> {
        let item = self.map.remove(&idx)?;
        self.remove_key(idx);

        // Re-link the predecessor, or move the head if there was none.
        if let Some(prev) = item.prev {
            self.map
                .get_mut(&prev)
                .expect("`prev` of a stored entry must refer to an entry stored in the map")
                .next = item.next;
        } else {
            self.first_idx = item.next;
        }

        // Re-link the successor, or move the tail if there was none.
        if let Some(next) = item.next {
            self.map
                .get_mut(&next)
                .expect("`next` of a stored entry must refer to an entry stored in the map")
                .prev = item.prev;
        } else {
            self.last_idx = item.prev;
        }

        Some(item.testcase)
    }

    /// Removes the entry stored under `idx` and returns its testcase cell,
    /// or `None` if nothing is stored under `idx`.
    #[cfg(feature = "corpus_btreemap")]
    pub fn remove(&mut self, idx: CorpusId) -> Option<RefCell<Testcase<I>>> {
        // `remove_key` is a no-op for unknown ids, so it is safe to call first.
        self.remove_key(idx);
        self.map.remove(&idx)
    }

    /// Returns the testcase cell stored under `idx`, if any.
    #[cfg(not(feature = "corpus_btreemap"))]
    #[must_use]
    pub fn get(&self, idx: CorpusId) -> Option<&RefCell<Testcase<I>>> {
        self.map.get(&idx).map(|item| &item.testcase)
    }

    /// Returns the testcase cell stored under `idx`, if any.
    #[cfg(feature = "corpus_btreemap")]
    #[must_use]
    pub fn get(&self, idx: CorpusId) -> Option<&RefCell<Testcase<I>>> {
        self.map.get(&idx)
    }

    /// Returns the id linked after `idx`, or `None` if `idx` is not stored
    /// or is the last entry.
    #[cfg(not(feature = "corpus_btreemap"))]
    #[must_use]
    fn next(&self, idx: CorpusId) -> Option<CorpusId> {
        self.map.get(&idx).and_then(|item| item.next)
    }

    /// Returns the smallest stored id greater than `idx`, or `None` if `idx`
    /// is not stored or is the largest stored id.
    #[cfg(feature = "corpus_btreemap")]
    #[must_use]
    fn next(&self, idx: CorpusId) -> Option<CorpusId> {
        let mut range = self
            .map
            .range((core::ops::Bound::Included(idx), core::ops::Bound::Unbounded));
        match range.next() {
            // The range starts at `idx` itself, so the successor is the next element.
            Some((this_id, _)) if *this_id == idx => range.next().map(|(next_id, _)| *next_id),
            // `idx` is not stored (or the range is empty): there is no successor.
            _ => None,
        }
    }

    /// Returns the id linked before `idx`, or `None` if `idx` is not stored
    /// or is the first entry.
    #[cfg(not(feature = "corpus_btreemap"))]
    #[must_use]
    fn prev(&self, idx: CorpusId) -> Option<CorpusId> {
        self.map.get(&idx).and_then(|item| item.prev)
    }

    /// Returns the largest stored id smaller than `idx`, or `None` if `idx`
    /// is not stored or is the smallest stored id.
    #[cfg(feature = "corpus_btreemap")]
    #[must_use]
    fn prev(&self, idx: CorpusId) -> Option<CorpusId> {
        let mut range = self
            .map
            .range((core::ops::Bound::Unbounded, core::ops::Bound::Included(idx)));
        match range.next_back() {
            // The range ends at `idx` itself, so the predecessor is the element before it.
            Some((this_id, _)) if *this_id == idx => {
                range.next_back().map(|(prev_id, _)| *prev_id)
            }
            // `idx` is not stored (or the range is empty): there is no predecessor.
            _ => None,
        }
    }

    /// Returns the id of the first entry of the linked list, if any.
    #[cfg(not(feature = "corpus_btreemap"))]
    #[must_use]
    fn first(&self) -> Option<CorpusId> {
        self.first_idx
    }

    /// Returns the smallest stored id, if any.
    #[cfg(feature = "corpus_btreemap")]
    #[must_use]
    fn first(&self) -> Option<CorpusId> {
        self.map.keys().next().copied()
    }

    /// Returns the id of the last entry of the linked list, if any.
    #[cfg(not(feature = "corpus_btreemap"))]
    #[must_use]
    fn last(&self) -> Option<CorpusId> {
        self.last_idx
    }

    /// Returns the largest stored id, if any.
    #[cfg(feature = "corpus_btreemap")]
    #[must_use]
    fn last(&self) -> Option<CorpusId> {
        self.map.keys().next_back().copied()
    }

    /// Creates an empty storage whose first assigned id is built from `0`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            map: TestcaseStorageMap::default(),
            keys: Vec::new(),
            progressive_idx: 0,
            #[cfg(not(feature = "corpus_btreemap"))]
            first_idx: None,
            #[cfg(not(feature = "corpus_btreemap"))]
            last_idx: None,
        }
    }
}
/// A corpus whose testcases are held in a `TestcaseStorage`.
#[derive(Default, Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "I: serde::de::DeserializeOwned")]
pub struct InMemoryCorpus<I>
where
I: Input,
{
/// The storage holding every testcase of this corpus.
storage: TestcaseStorage<I>,
/// The id handed out by `current` / `current_mut`; `None` after `new`.
current: Option<CorpusId>,
}
/// The corpus handles testcases whose input type is `I`.
impl<I: Input> UsesInput for InMemoryCorpus<I> {
    type Input = I;
}
impl<I> Corpus for InMemoryCorpus<I>
where
    I: Input,
{
    /// Returns the number of testcases currently stored.
    #[inline]
    fn count(&self) -> usize {
        self.storage.map.len()
    }

    /// Stores `testcase` and returns the id assigned to it. Never fails.
    #[inline]
    fn add(&mut self, testcase: Testcase<I>) -> Result<CorpusId, Error> {
        Ok(self.storage.insert(RefCell::new(testcase)))
    }

    /// Replaces the testcase stored under `idx` and returns the old one.
    ///
    /// # Errors
    ///
    /// Returns a key-not-found error if nothing is stored under `idx`.
    #[inline]
    fn replace(&mut self, idx: CorpusId, testcase: Testcase<I>) -> Result<Testcase<I>, Error> {
        self.storage
            .replace(idx, testcase)
            .ok_or_else(|| Error::key_not_found(format!("Index {idx} not found")))
    }

    /// Removes the testcase stored under `idx` and returns it.
    ///
    /// # Errors
    ///
    /// Returns a key-not-found error if nothing is stored under `idx`.
    #[inline]
    fn remove(&mut self, idx: CorpusId) -> Result<Testcase<I>, Error> {
        self.storage
            .remove(idx)
            // The cell is owned here, so unwrap it directly instead of
            // swapping a throw-away default testcase into it.
            .map(RefCell::into_inner)
            .ok_or_else(|| Error::key_not_found(format!("Index {idx} not found")))
    }

    /// Returns the cell holding the testcase stored under `idx`.
    ///
    /// # Errors
    ///
    /// Returns a key-not-found error if nothing is stored under `idx`.
    #[inline]
    fn get(&self, idx: CorpusId) -> Result<&RefCell<Testcase<I>>, Error> {
        self.storage
            .get(idx)
            .ok_or_else(|| Error::key_not_found(format!("Index {idx} not found")))
    }

    /// Returns a shared reference to the current id slot.
    #[inline]
    fn current(&self) -> &Option<CorpusId> {
        &self.current
    }

    /// Returns a mutable reference to the current id slot.
    #[inline]
    fn current_mut(&mut self) -> &mut Option<CorpusId> {
        &mut self.current
    }

    /// Returns the id following `idx` in the storage, if any.
    #[inline]
    fn next(&self, idx: CorpusId) -> Option<CorpusId> {
        self.storage.next(idx)
    }

    /// Returns the id preceding `idx` in the storage, if any.
    #[inline]
    fn prev(&self, idx: CorpusId) -> Option<CorpusId> {
        self.storage.prev(idx)
    }

    /// Returns the first id of the storage, or `None` if it is empty.
    #[inline]
    fn first(&self) -> Option<CorpusId> {
        self.storage.first()
    }

    /// Returns the last id of the storage, or `None` if it is empty.
    #[inline]
    fn last(&self) -> Option<CorpusId> {
        self.storage.last()
    }

    /// Returns the id at position `nth` of the storage's sorted key list.
    ///
    /// # Panics
    ///
    /// Panics if `nth` is not smaller than the number of stored keys.
    #[inline]
    fn nth(&self, nth: usize) -> CorpusId {
        self.storage.keys[nth]
    }
}
impl<I: Input> InMemoryCorpus<I> {
    /// Builds a corpus with an empty storage and no current id.
    #[must_use]
    pub fn new() -> Self {
        let storage = TestcaseStorage::new();
        Self {
            storage,
            current: None,
        }
    }
}
/// Python bindings for `InMemoryCorpus`, compiled only with the `python` feature.
#[cfg(feature = "python")]
pub mod pybind {
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};
use crate::{
corpus::{pybind::PythonCorpus, InMemoryCorpus},
inputs::BytesInput,
};
// Wrapper around an `InMemoryCorpus<BytesInput>`, registered with pyo3 under
// the Python class name `InMemoryCorpus`.
#[pyclass(unsendable, name = "InMemoryCorpus")]
#[allow(clippy::unsafe_derive_deserialize)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PythonInMemoryCorpus {
// The wrapped Rust corpus.
pub inner: InMemoryCorpus<BytesInput>,
}
#[pymethods]
impl PythonInMemoryCorpus {
// Python-side constructor: wraps a fresh, empty `InMemoryCorpus`.
#[new]
fn new() -> Self {
Self {
inner: InMemoryCorpus::new(),
}
}
// Hands this object to `PythonCorpus::new_in_memory` and returns the result.
fn as_corpus(slf: Py<Self>) -> PythonCorpus {
PythonCorpus::new_in_memory(slf)
}
}
/// Adds the `PythonInMemoryCorpus` class to the Python module `m`.
pub fn register(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PythonInMemoryCorpus>()?;
Ok(())
}
}