1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use std::hash::{Hash, Hasher};

use pyo3::class::{PyObjectProtocol, PySequenceProtocol};
use pyo3::prelude::{pyclass, pyfunction, pymethods, pyproto, PyModule, PyObject, PyResult};
use pyo3::types::PyTuple;
use pyo3::{
    exceptions, wrap_pyfunction, ObjectProtocol, PyAny, PyCell, PyErr, PyIterProtocol, PyRefMut,
    Python,
};

use crate::object::{extract_py_object, Object};

/// The `rpds` persistent list used as backing storage; elements are wrapped Python objects.
type RpdsList = rpds::List<Object>;

/// Python-visible persistent list, exported to Python through `#[pyclass]`.
///
/// The methods on this type take `&self` and return a new `List` instead of modifying the
/// receiver.
#[pyclass]
#[derive(Default)]
pub struct List {
    // The underlying persistent list; index 0 / `first()` is the head.
    value: RpdsList,
}

impl List {
    /// Creates a list with no elements.
    #[must_use]
    pub fn new() -> Self {
        let value = RpdsList::new();
        Self { value }
    }
}

#[pymethods]
impl List {
    /// Returns a new list with `py_object` placed at the head; `self` is not modified.
    pub fn push_front(&self, py_object: PyObject) -> PyResult<Self> {
        let value = self.value.push_front(Object::new(py_object));
        Ok(Self { value })
    }

    /// Returns a new list with the elements in the opposite order; `self` is not modified.
    pub fn reverse(&self) -> PyResult<Self> {
        Ok(Self {
            value: self.value.reverse(),
        })
    }

    /// The head element of the list, converted back to a Python object.
    ///
    /// How an empty list is reported is decided by `extract_py_object` (defined elsewhere).
    #[getter]
    pub fn first(&self) -> PyResult<PyObject> {
        let head = self.value.first();
        extract_py_object(head)
    }
}

impl Hash for List {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Add the hash of length so that if two collections are added one after the other it doesn't
        // hash to the same thing as a single collection with the same elements in the same order.
        self.value.len().hash(state);
        for element in self.value.iter() {
            element.hash(state);
        }
    }
}

#[pyproto]
impl PySequenceProtocol for List {
    // Backs Python's `len()`: the number of elements currently in the list.
    fn __len__(&self) -> PyResult<usize> {
        Ok(self.value.len())
    }
}

#[pyproto]
impl PyIterProtocol for List {
    // Backs Python's `iter()`: every element is converted up front into a Vec, which is
    // then handed to a PyObjectIterator. The first conversion error aborts and is returned.
    fn __iter__(slf: PyRefMut<Self>) -> PyResult<crate::iterators::PyObjectIterator> {
        let snapshot = slf
            .value
            .iter()
            .map(|item| extract_py_object(Some(item)))
            .collect::<PyResult<Vec<_>>>()?;
        Ok(crate::iterators::PyObjectIterator::new(snapshot.into_iter()))
    }
}

// NOTE(review): this macro is defined elsewhere in the crate. Presumably it implements
// `PyObjectProtocol` for `List` (e.g. `__hash__` / `__str__`) on top of the `Hash` and
// `Display` impls in this file; confirm against the macro definition.
py_object_protocol!(List);

impl std::fmt::Display for List {
    /// Renders the fixed text `plist`; the elements themselves are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("plist")
    }
}

/// Builds a persistent list from an optional iterable, mirroring `pyrsistent.plist`.
///
/// * `plist()` returns an empty list.
/// * `plist(iterable)` returns a list whose elements appear in the same order the iterable
///   yields them, so `plist([1, 2, 3]).first` is `1`.
///
/// # Errors
///
/// * `ValueError` if more than one positional argument is given.
/// * Any error raised while iterating the argument (e.g. when it is not iterable) or while
///   extracting an element is propagated to Python instead of panicking.
#[pyfunction(args = "*")]
fn plist(args: &PyTuple) -> PyResult<List> {
    let mut list = List::new();
    if args.is_empty() {
        return Ok(list);
    } else if args.len() > 1 {
        return Err(PyErr::new::<exceptions::ValueError, _>(
            "Incorrect number of arguments!!",
        ));
    }

    // Collect first: the list only grows at the head, so elements must be pushed
    // back-to-front for the result to keep the iterable's order.
    let elements = args
        .get_item(0)
        .as_ref()
        .iter()?
        .map(|element| element?.extract::<PyObject>())
        .collect::<PyResult<Vec<_>>>()?;
    for element in elements.into_iter().rev() {
        list = list.push_front(element)?;
    }
    Ok(list)
}

/// Builds a persistent list from the positional arguments.
///
/// `l(1, 2, 3)` is equivalent to `plist([1, 2, 3])`: the first argument becomes the head.
///
/// # Errors
///
/// Propagates any error raised while extracting an argument or building the list.
#[pyfunction(args = "*")]
fn l(args: &PyTuple) -> PyResult<List> {
    let elements = args
        .iter()
        .map(|element| element.extract::<PyObject>())
        .collect::<PyResult<Vec<_>>>()?;

    // Push back-to-front so the first argument ends up at the head of the list.
    let mut list = List::new();
    for element in elements.into_iter().rev() {
        list = list.push_front(element)?;
    }
    Ok(list)
}

/// Registers the `List` class and the `plist` / `l` constructor functions on module `m`.
///
/// # Errors
///
/// Returns the error raised while adding the class or either function to the module.
/// It does not panic.
pub fn py_binding(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<List>()?;
    m.add_wrapped(wrap_pyfunction!(plist))?;
    m.add_wrapped(wrap_pyfunction!(l))?;

    Ok(())
}