Skip to main content

rustpython_vm/
class.rs

//! Utilities to define a new Python class
2
3use crate::{
4    PyPayload,
5    builtins::{PyBaseObject, PyType, PyTypeRef, descriptor::PyWrapper},
6    function::PyMethodDef,
7    object::Py,
8    types::{PyTypeFlags, PyTypeSlots, SLOT_DEFS, hash_not_implemented},
9    vm::Context,
10};
11use rustpython_common::static_cell;
12
13/// Add slot wrapper descriptors to a type's dict
14///
15/// Iterates SLOT_DEFS and creates a PyWrapper for each slot that:
16/// 1. Has a function set in the type's slots
17/// 2. Doesn't already have an attribute in the type's dict
18pub fn add_operators(class: &'static Py<PyType>, ctx: &Context) {
19    for def in SLOT_DEFS.iter() {
20        // Skip __new__ - it has special handling
21        if def.name == "__new__" {
22            continue;
23        }
24
25        // Special handling for __hash__ = None
26        if def.name == "__hash__"
27            && class
28                .slots
29                .hash
30                .load()
31                .is_some_and(|h| h as usize == hash_not_implemented as *const () as usize)
32        {
33            class.set_attr(ctx.names.__hash__, ctx.none.clone().into());
34            continue;
35        }
36
37        // __getattr__ should only have a wrapper if the type explicitly defines it.
38        // Unlike __getattribute__, __getattr__ is not present on object by default.
39        // Both map to TpGetattro, but only __getattribute__ gets a wrapper from the slot.
40        if def.name == "__getattr__" {
41            continue;
42        }
43
44        // Get the slot function wrapped in SlotFunc
45        let Some(slot_func) = def.accessor.get_slot_func_with_op(&class.slots, def.op) else {
46            continue;
47        };
48
49        // Check if attribute already exists in dict
50        let attr_name = ctx.intern_str(def.name);
51        if class.attributes.read().contains_key(attr_name) {
52            continue;
53        }
54
55        // Create and add the wrapper
56        let wrapper = PyWrapper {
57            typ: class,
58            name: attr_name,
59            wrapped: slot_func,
60            doc: Some(def.doc),
61        };
62        class.set_attr(attr_name, wrapper.into_ref(ctx).into());
63    }
64}
65
66pub trait StaticType {
67    // Ideally, saving PyType is better than PyTypeRef
68    fn static_cell() -> &'static static_cell::StaticCell<PyTypeRef>;
69    #[inline]
70    fn static_metaclass() -> &'static Py<PyType> {
71        PyType::static_type()
72    }
73    #[inline]
74    fn static_baseclass() -> &'static Py<PyType> {
75        PyBaseObject::static_type()
76    }
77    #[inline]
78    fn static_type() -> &'static Py<PyType> {
79        #[cold]
80        fn fail() -> ! {
81            panic!(
82                "static type has not been initialized. e.g. the native types defined in different module may be used before importing library."
83            );
84        }
85        Self::static_cell().get().unwrap_or_else(|| fail())
86    }
87    fn init_manually(typ: PyTypeRef) -> &'static Py<PyType> {
88        let cell = Self::static_cell();
89        cell.set(typ)
90            .unwrap_or_else(|_| panic!("double initialization from init_manually"));
91        cell.get().unwrap()
92    }
93    fn init_builtin_type() -> &'static Py<PyType>
94    where
95        Self: PyClassImpl,
96    {
97        let typ = Self::create_static_type();
98        let cell = Self::static_cell();
99        cell.set(typ)
100            .unwrap_or_else(|_| panic!("double initialization of {}", Self::NAME));
101        cell.get().unwrap()
102    }
103    fn create_static_type() -> PyTypeRef
104    where
105        Self: PyClassImpl,
106    {
107        PyType::new_static(
108            Self::static_baseclass().to_owned(),
109            Default::default(),
110            Self::make_slots(),
111            Self::static_metaclass().to_owned(),
112        )
113        .unwrap()
114    }
115}
116
/// Compile-time description of a native type exposed to Python.
pub trait PyClassDef {
    /// Payload type of the base class.
    ///
    /// Because of a restriction of the Rust trait system, every class must name a base. So
    /// `PyBaseObject::Base` is `PyBaseObject` itself, even though Python's `object.__base__`
    /// is `None`.
    type Base: PyClassDef;

    /// Python-visible class name.
    const NAME: &'static str;
    /// Module written to the class's `__module__` attribute, if any.
    const MODULE_NAME: Option<&'static str>;
    /// Name stored in the type slots (`PyTypeSlots::name`).
    const TP_NAME: &'static str;
    /// Instance size recorded in the type slots (`PyTypeSlots::basicsize`).
    const BASICSIZE: usize;

    /// Class docstring, stored in the slots and installed as `__doc__`. Defaults to `None`.
    const DOC: Option<&'static str> = None;
    /// Per-item size recorded in the type slots (`PyTypeSlots::itemsize`). Defaults to `0`.
    const ITEMSIZE: usize = 0;
    /// When `true`, the hash slot is set to `hash_not_implemented`, so the class exposes
    /// `__hash__ = None`. Defaults to `false`.
    const UNHASHABLE: bool = false;
}
130
131pub trait PyClassImpl: PyClassDef {
132    const TP_FLAGS: PyTypeFlags = PyTypeFlags::DEFAULT;
133
134    fn extend_class(ctx: &'static Context, class: &'static Py<PyType>)
135    where
136        Self: Sized,
137    {
138        #[cfg(debug_assertions)]
139        {
140            assert!(class.slots.flags.is_created_with_flags());
141        }
142
143        let _ = ctx.intern_str(Self::NAME); // intern type name
144
145        if Self::TP_FLAGS.has_feature(PyTypeFlags::HAS_DICT) {
146            let __dict__ = identifier!(ctx, __dict__);
147            class.set_attr(
148                __dict__,
149                ctx.new_static_getset(
150                    "__dict__",
151                    class,
152                    crate::builtins::object::object_get_dict,
153                    crate::builtins::object::object_set_dict,
154                )
155                .into(),
156            );
157        }
158        Self::impl_extend_class(ctx, class);
159        if let Some(doc) = Self::DOC {
160            // Only set __doc__ if it doesn't already exist (e.g., as a member descriptor)
161            // This matches CPython's behavior in type_dict_set_doc
162            let doc_attr_name = identifier!(ctx, __doc__);
163            if class.attributes.read().get(doc_attr_name).is_none() {
164                class.set_attr(doc_attr_name, ctx.new_str(doc).into());
165            }
166        }
167        if let Some(module_name) = Self::MODULE_NAME {
168            let module_key = identifier!(ctx, __module__);
169            // Don't overwrite a getset descriptor for __module__ (e.g. TypeAliasType
170            // has an instance-level __module__ getset that should not be replaced)
171            let has_getset = class
172                .attributes
173                .read()
174                .get(module_key)
175                .is_some_and(|v| v.downcastable::<crate::builtins::PyGetSet>());
176            if !has_getset {
177                class.set_attr(module_key, ctx.new_str(module_name).into());
178            }
179        }
180
181        // Don't add __new__ attribute if slot_new is inherited from object
182        // (Python doesn't add __new__ to __dict__ for inherited slots)
183        // Exception: object itself should have __new__ in its dict
184        if let Some(slot_new) = class.slots.new.load() {
185            let object_new = ctx.types.object_type.slots.new.load();
186            let is_object_itself = core::ptr::eq(class, ctx.types.object_type);
187            let is_inherited_from_object = !is_object_itself
188                && object_new.is_some_and(|obj_new| slot_new as usize == obj_new as usize);
189
190            if !is_inherited_from_object {
191                let bound_new =
192                    ctx.slot_new_wrapper
193                        .build_bound_method(ctx, class.to_owned().into(), class);
194                class.set_attr(identifier!(ctx, __new__), bound_new.into());
195            }
196        }
197
198        // Add slot wrappers using SLOT_DEFS array
199        add_operators(class, ctx);
200
201        // Inherit slots from base types after slots are fully initialized
202        for base in class.bases.read().iter() {
203            class.inherit_slots(base);
204        }
205
206        class.extend_methods(class.slots.methods, ctx);
207    }
208
209    fn make_static_type() -> PyTypeRef
210    where
211        Self: StaticType + Sized,
212    {
213        (*Self::static_cell().get_or_init(|| {
214            let typ = Self::create_static_type();
215            Self::extend_class(Context::genesis(), unsafe {
216                // typ will be saved in static_cell
217                let r: &Py<PyType> = &typ;
218                let r: &'static Py<PyType> = core::mem::transmute(r);
219                r
220            });
221            typ
222        }))
223        .to_owned()
224    }
225
226    fn impl_extend_class(ctx: &'static Context, class: &'static Py<PyType>);
227    const METHOD_DEFS: &'static [PyMethodDef];
228    fn extend_slots(slots: &mut PyTypeSlots);
229
230    fn make_slots() -> PyTypeSlots {
231        let mut slots = PyTypeSlots {
232            flags: Self::TP_FLAGS,
233            name: Self::TP_NAME,
234            basicsize: Self::BASICSIZE,
235            itemsize: Self::ITEMSIZE,
236            doc: Self::DOC,
237            methods: Self::METHOD_DEFS,
238            ..Default::default()
239        };
240
241        if Self::UNHASHABLE {
242            slots.hash.store(Some(hash_not_implemented));
243        }
244
245        Self::extend_slots(&mut slots);
246        slots
247    }
248}
249
250/// Trait for Python subclasses that can provide a reference to their base type.
251///
252/// This trait is automatically implemented by the `#[pyclass]` macro when
253/// `base = SomeType` is specified. It provides safe reference access to the
254/// base type's payload.
255///
256/// For subclasses with `#[repr(transparent)]`
257/// which enables ownership transfer via `into_base()`.
258pub trait PySubclass: crate::PyPayload {
259    type Base: crate::PyPayload;
260
261    /// Returns a reference to the base type's payload.
262    fn as_base(&self) -> &Self::Base;
263}