use crate::ffi_ptr_ext::FfiPtrExt;
use crate::impl_::pyclass::{PyClassBaseType, PyClassImpl};
use crate::impl_::pyclass_init::{PyNativeTypeInitializer, PyObjectInit};
use crate::pycell::impl_::PyClassObjectLayout;
use crate::{ffi, Bound, PyClass, PyResult, Python};
use crate::{ffi::PyTypeObject, pycell::impl_::PyClassObjectContents};
use std::marker::PhantomData;
/// Initializer for an object of the `PyClass` type `T`.
///
/// Pairs the Rust value that ends up stored inside the new object with the
/// initializer of `T`'s base type, which is the part that actually produces
/// the raw object.
pub struct PyClassInitializer<T: PyClass> {
/// Value written into the object's contents after the base initializer has
/// returned the raw object.
init: T,
/// Initializer belonging to `T::BaseType`; it is asked for the new raw object
/// before `init` is written.
super_init: <T::BaseType as PyClassBaseType>::Initializer,
}
impl<T: PyClass> PyClassInitializer<T> {
#[track_caller]
#[inline]
pub fn new(init: T, super_init: <T::BaseType as PyClassBaseType>::Initializer) -> Self {
Self { init, super_init }
}
#[track_caller]
#[inline]
pub fn add_subclass<S>(self, subclass_value: S) -> PyClassInitializer<S>
where
T: PyClassBaseType<Initializer = Self>,
S: PyClass<BaseType = T>,
{
PyClassInitializer::new(subclass_value, self)
}
pub(crate) fn create_class_object(self, py: Python<'_>) -> PyResult<Bound<'_, T>>
where
T: PyClass,
{
unsafe { self.create_class_object_of_type(py, T::type_object_raw(py)) }
}
pub(crate) unsafe fn create_class_object_of_type(
self,
py: Python<'_>,
target_type: *mut crate::ffi::PyTypeObject,
) -> PyResult<Bound<'_, T>>
where
T: PyClass,
{
let obj = unsafe { self.super_init.into_new_object(py, target_type)? };
let contents = unsafe { <T as PyClassImpl>::Layout::contents_uninit(obj) };
unsafe { (*contents).write(PyClassObjectContents::new(self.init)) };
Ok(unsafe { obj.assume_owned(py).cast_into_unchecked() })
}
}
impl<T: PyClass> PyObjectInit<T> for PyClassInitializer<T> {
    /// Creates the object via `create_class_object_of_type` and hands it back
    /// as a raw pointer, giving up the `Bound` wrapper.
    unsafe fn into_new_object(
        self,
        py: Python<'_>,
        subtype: *mut PyTypeObject,
    ) -> PyResult<*mut ffi::PyObject> {
        let object = unsafe { self.create_class_object_of_type(py, subtype) }?;
        Ok(Bound::into_ptr(object))
    }
}
impl<T> From<T> for PyClassInitializer<T>
where
T: PyClass,
T::BaseType: PyClassBaseType<Initializer = PyNativeTypeInitializer<T::BaseType>>,
{
#[inline]
fn from(value: T) -> PyClassInitializer<T> {
Self::new(value, PyNativeTypeInitializer(PhantomData))
}
}
impl<S, B> From<(S, B)> for PyClassInitializer<S>
where
    S: PyClass<BaseType = B>,
    B: PyClass + PyClassBaseType<Initializer = PyClassInitializer<B>>,
    B::BaseType: PyClassBaseType<Initializer = PyNativeTypeInitializer<B::BaseType>>,
{
    /// Builds the initializer for subclass `S` from a `(subclass, base)` pair
    /// of values: the base value is converted first, then the subclass value
    /// is layered on top of it.
    #[track_caller]
    #[inline]
    fn from((sub, base): (S, B)) -> PyClassInitializer<S> {
        let base_init = PyClassInitializer::from(base);
        base_init.add_subclass(sub)
    }
}