use std::{
marker::PhantomData,
ptr::{self, NonNull},
};
use crate::{header::Header, Pointer};
/// A node of a [`LinkedList`].
///
/// Nodes are not allocated by the list: `LinkedList::append` and
/// `LinkedList::insert_after` write them with `ptr::write` into memory the
/// caller provides.
pub(crate) struct Node<T> {
    /// The following node, or `None` when this node is the tail.
    pub next: Pointer<Self>,
    /// The preceding node, or `None` when this node is the head.
    pub prev: Pointer<Self>,
    /// Payload stored in this node.
    pub data: T,
}
/// Doubly linked list of [`Node`]s that live in externally managed memory.
///
/// The list only stores pointers. The caller supplies the memory for each node
/// and is responsible for keeping it valid while the node is linked.
pub(crate) struct LinkedList<T> {
    /// First node, or `None` when the list is empty.
    head: Pointer<Node<T>>,
    /// Last node, or `None` when the list is empty.
    tail: Pointer<Node<T>>,
    /// Number of nodes currently linked into the list.
    len: usize,
    /// Type marker for `T`; the list holds no `T` values directly.
    marker: PhantomData<T>,
}
/// Iterator over the node pointers of a [`LinkedList`], from head to tail.
///
/// NOTE(review): `Iter` stores raw pointers and has no lifetime tied to the
/// list. The borrow checker therefore does not prevent the list from being
/// mutated, or its nodes freed, while an `Iter` is alive.
pub(crate) struct Iter<T> {
    /// Next node to yield, or `None` once iteration is finished.
    current: Pointer<Node<T>>,
    /// Number of nodes not yet yielded; initialised from the list's length.
    len: usize,
    /// Type marker for `T`.
    marker: PhantomData<T>,
}
impl<T> LinkedList<T> {
    /// Creates an empty list. Usable in `const` contexts.
    pub const fn new() -> Self {
        Self {
            head: None,
            tail: None,
            len: 0,
            marker: PhantomData,
        }
    }

    /// Number of nodes currently linked into the list.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Pointer to the head node, or `None` if the list is empty.
    #[inline]
    pub fn first(&self) -> Pointer<Header<T>> {
        self.head
    }

    /// Pointer to the tail node, or `None` if the list is empty.
    #[cfg(test)]
    #[inline]
    pub fn last(&self) -> Pointer<Header<T>> {
        self.tail
    }

    /// Writes a new node holding `data` at `address` and links it as the new tail.
    ///
    /// Returns a pointer to the freshly written node.
    ///
    /// # Safety
    ///
    /// - `address` must be valid for writes of a `Node<T>` and properly aligned for it.
    /// - The memory at `address` must not overlap any node currently in the list, and
    ///   it must stay valid for as long as the node remains linked.
    /// - Every node currently linked into the list must be valid.
    pub unsafe fn append(&mut self, data: T, address: NonNull<u8>) -> NonNull<Header<T>> {
        let node = address.cast();
        ptr::write(
            node.as_ptr(),
            Node {
                prev: self.tail,
                next: None,
                data,
            },
        );

        match self.tail {
            Some(mut tail) => tail.as_mut().next = Some(node),
            // Empty list: the new node is also the head.
            None => self.head = Some(node),
        }

        self.tail = Some(node);
        self.len += 1;
        node
    }

    /// Writes a new node holding `data` at `address` and links it directly after `node`.
    ///
    /// Returns a pointer to the freshly written node. If `node` was the tail, the new
    /// node becomes the tail.
    ///
    /// # Safety
    ///
    /// - `node` must currently be linked into *this* list.
    /// - `address` must be valid for writes of a `Node<T>` and properly aligned for it.
    /// - The memory at `address` must not overlap any node currently in the list, and
    ///   it must stay valid for as long as the node remains linked.
    /// - Every node currently linked into the list must be valid.
    pub unsafe fn insert_after(
        &mut self,
        mut node: NonNull<Node<T>>,
        data: T,
        address: NonNull<u8>,
    ) -> NonNull<Header<T>> {
        let next = node.as_ref().next;
        let new_node = address.cast();
        ptr::write(
            new_node.as_ptr(),
            Node {
                prev: Some(node),
                next,
                data,
            },
        );

        // Branch on `node`'s own successor rather than comparing against an
        // unwrapped `self.tail`: `next` is `None` exactly when `node` is the tail.
        match next {
            Some(mut successor) => successor.as_mut().prev = Some(new_node),
            None => self.tail = Some(new_node),
        }

        node.as_mut().next = Some(new_node);
        self.len += 1;
        new_node
    }

    /// Unlinks `node` from the list.
    ///
    /// The node's memory is neither freed nor dropped. Its `data` is left in place,
    /// and its own `prev`/`next` fields are not modified.
    ///
    /// # Safety
    ///
    /// - `node` must currently be linked into *this* list.
    /// - Every node currently linked into the list must be valid.
    pub unsafe fn remove(&mut self, node: NonNull<Node<T>>) {
        let (prev, next) = {
            let node = node.as_ref();
            (node.prev, node.next)
        };

        // Splice the neighbours together. A missing neighbour means `node` sat at
        // that end of the list, so the corresponding end pointer moves instead.
        // A lone node has neither neighbour, which clears both `head` and `tail`.
        match prev {
            Some(mut before) => before.as_mut().next = next,
            None => self.head = next,
        }
        match next {
            Some(mut after) => after.as_mut().prev = prev,
            None => self.tail = prev,
        }

        self.len -= 1;
    }

    /// Returns an iterator over the list's node pointers, from head to tail.
    pub fn iter(&self) -> Iter<T> {
        Iter {
            current: self.head,
            len: self.len,
            marker: PhantomData,
        }
    }
}
/// Walks the list head to tail, yielding a pointer to each node.
impl<T> Iterator for Iter<T> {
    type Item = NonNull<Node<T>>;

    fn next(&mut self) -> Option<Self::Item> {
        let node = self.current?;
        // SAFETY: nodes reachable from the list are assumed to still be valid.
        // `Iter` holds no borrow of the list, so it cannot enforce this itself.
        self.current = unsafe { node.as_ref().next };
        self.len -= 1;
        Some(node)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `len` tracks how many nodes remain, so lower and upper bounds coincide.
        let remaining = self.len;
        (remaining, Some(remaining))
    }
}
impl<T> IntoIterator for &LinkedList<T> {
type Item = NonNull<Node<T>>;
type IntoIter = Iter<T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[cfg(test)]
mod tests {
    use std::{mem, ptr::NonNull};

    use super::*;
    use crate::platform;

    /// Exercises append, insert (middle and tail), remove (middle, tail, head, sole
    /// node), iteration, and re-use of the list after it has been emptied.
    #[test]
    fn linked_list_operations() {
        unsafe {
            let mut list: LinkedList<u8> = LinkedList::new();
            let region = platform::request_memory(platform::page_size()).unwrap();
            let size = mem::size_of::<Node<u8>>();
            // Address of the `i`-th node-sized slot inside the requested page.
            let slot = |i: usize| NonNull::new_unchecked(region.as_ptr().add(size * i));
            // Payloads in iteration order (head to tail).
            let values = |list: &LinkedList<u8>| -> Vec<u8> {
                list.iter().map(|node| node.as_ref().data).collect()
            };

            assert_eq!(list.len(), 0);
            assert_eq!(list.first(), None);
            assert!(values(&list).is_empty());

            let node1 = list.append(1, slot(0));
            let node2 = list.append(2, slot(1));
            let node3 = list.append(3, slot(2));
            assert_eq!(list.len(), 3);
            assert_eq!(node1.as_ref().data, 1);
            assert_eq!(node2.as_ref().data, 2);
            assert_eq!(node3.as_ref().data, 3);
            assert_eq!(list.head, Some(node1));
            assert_eq!(list.tail, Some(node3));
            assert_eq!(list.first(), Some(node1));
            assert_eq!(list.last(), Some(node3));
            assert_eq!(node1.as_ref().next, Some(node2));
            assert_eq!(node1.as_ref().prev, None);
            assert_eq!(node2.as_ref().next, Some(node3));
            assert_eq!(node2.as_ref().prev, Some(node1));
            assert_eq!(node3.as_ref().next, None);
            assert_eq!(node3.as_ref().prev, Some(node2));
            assert_eq!(values(&list), vec![1, 2, 3]);
            assert_eq!(list.iter().size_hint(), (3, Some(3)));

            // Insert in the middle: the tail must not move.
            let node4 = list.insert_after(node2, 4, slot(3));
            assert_eq!(list.len(), 4);
            assert_eq!(list.tail, Some(node3));
            assert_eq!(node4.as_ref().data, 4);
            assert_eq!(node4.as_ref().next, Some(node3));
            assert_eq!(node4.as_ref().prev, Some(node2));
            assert_eq!(node2.as_ref().next, Some(node4));
            assert_eq!(node2.as_ref().prev, Some(node1));
            assert_eq!(node3.as_ref().next, None);
            assert_eq!(node3.as_ref().prev, Some(node4));
            assert_eq!(values(&list), vec![1, 2, 4, 3]);

            // Remove from the middle.
            list.remove(node4);
            assert_eq!(list.len(), 3);
            assert_eq!(node2.as_ref().next, Some(node3));
            assert_eq!(node2.as_ref().prev, Some(node1));
            assert_eq!(node3.as_ref().next, None);
            assert_eq!(node3.as_ref().prev, Some(node2));
            assert_eq!(values(&list), vec![1, 2, 3]);

            // Remove the tail.
            list.remove(node3);
            assert_eq!(list.len(), 2);
            assert_eq!(Some(node1), list.head);
            assert_eq!(Some(node2), list.tail);
            assert_eq!(node2.as_ref().next, None);
            assert_eq!(node2.as_ref().prev, Some(node1));

            // Remove the head.
            list.remove(node1);
            assert_eq!(list.len(), 1);
            assert_eq!(Some(node2), list.head);
            assert_eq!(Some(node2), list.tail);
            assert_eq!(node2.as_ref().next, None);
            assert_eq!(node2.as_ref().prev, None);

            // Remove the only node.
            list.remove(node2);
            assert_eq!(list.tail, None);
            assert_eq!(list.head, None);
            assert_eq!(list.len(), 0);
            assert!(values(&list).is_empty());

            // Re-use the emptied list; inserting after the tail must move the tail.
            let node5 = list.append(5, slot(4));
            let node6 = list.insert_after(node5, 6, slot(5));
            assert_eq!(list.len(), 2);
            assert_eq!(list.head, Some(node5));
            assert_eq!(list.tail, Some(node6));
            assert_eq!(node5.as_ref().next, Some(node6));
            assert_eq!(node5.as_ref().prev, None);
            assert_eq!(node6.as_ref().prev, Some(node5));
            assert_eq!(node6.as_ref().next, None);
            assert_eq!(values(&list), vec![5, 6]);

            list.remove(node5);
            assert_eq!(list.head, Some(node6));
            assert_eq!(list.tail, Some(node6));
            assert_eq!(node6.as_ref().prev, None);

            list.remove(node6);
            assert_eq!(list.head, None);
            assert_eq!(list.tail, None);
            assert_eq!(list.len(), 0);

            platform::return_memory(region, platform::page_size());
        }
    }
}