1use rustpython_common::wtf8::{Wtf8, Wtf8Buf};
2
3use crate::{
4 AsObject, Py, PyExact, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, VirtualMachine,
5 builtins::{PyStr, PyStrInterned, PyTypeRef},
6 common::lock::PyRwLock,
7 convert::ToPyObject,
8};
9use alloc::borrow::ToOwned;
10use core::{borrow::Borrow, ops::Deref};
11
12#[derive(Debug)]
13pub struct StringPool {
14 inner: PyRwLock<std::collections::HashSet<CachedPyStrRef, ahash::RandomState>>,
15}
16
17impl Default for StringPool {
18 fn default() -> Self {
19 Self {
20 inner: PyRwLock::new(Default::default()),
21 }
22 }
23}
24
25impl Clone for StringPool {
26 fn clone(&self) -> Self {
27 Self {
28 inner: PyRwLock::new(self.inner.read().clone()),
29 }
30 }
31}
32
33impl StringPool {
34 #[cfg(all(unix, feature = "threading"))]
40 pub(crate) unsafe fn reinit_after_fork(&self) {
41 unsafe { crate::common::lock::reinit_rwlock_after_fork(&self.inner) };
42 }
43
44 #[inline]
45 pub unsafe fn intern<S: InternableString>(
46 &self,
47 s: S,
48 typ: PyTypeRef,
49 ) -> &'static PyStrInterned {
50 if let Some(found) = self.interned(s.as_ref()) {
51 return found;
52 }
53
54 #[cold]
55 fn miss(zelf: &StringPool, s: PyRefExact<PyStr>) -> &'static PyStrInterned {
56 let cache = CachedPyStrRef { inner: s };
57 let inserted = zelf.inner.write().insert(cache.clone());
58 if inserted {
59 let interned = unsafe { cache.as_interned_str() };
60 unsafe { interned.as_object().mark_intern() };
61 interned
62 } else {
63 unsafe {
64 zelf.inner
65 .read()
66 .get(cache.as_ref())
67 .expect("inserted is false")
68 .as_interned_str()
69 }
70 }
71 }
72 let str_ref = s.into_pyref_exact(typ);
73 miss(self, str_ref)
74 }
75
76 #[inline]
77 pub fn interned<S: MaybeInternedString + ?Sized>(
78 &self,
79 s: &S,
80 ) -> Option<&'static PyStrInterned> {
81 if let Some(interned) = s.as_interned() {
82 return Some(interned);
83 }
84 self.inner
85 .read()
86 .get(s.as_ref())
87 .map(|cached| unsafe { cached.as_interned_str() })
88 }
89}
90
/// A string-pool entry: an exact `str` reference whose hashing and equality go by string
/// contents, so the pool can be probed with a plain `&Wtf8`.
#[derive(Debug, Clone)]
// `repr(transparent)` is load-bearing: `as_interned_str` reinterprets the bits of this
// struct (i.e. of `inner`) as a `&'static PyStrInterned`.
#[repr(transparent)]
pub struct CachedPyStrRef {
    inner: PyRefExact<PyStr>,
}
96
97impl core::hash::Hash for CachedPyStrRef {
98 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
99 self.inner.as_wtf8().hash(state)
100 }
101}
102
103impl PartialEq for CachedPyStrRef {
104 fn eq(&self, other: &Self) -> bool {
105 self.inner.as_wtf8() == other.inner.as_wtf8()
106 }
107}
108
109impl Eq for CachedPyStrRef {}
110
111impl core::borrow::Borrow<Wtf8> for CachedPyStrRef {
112 #[inline]
113 fn borrow(&self) -> &Wtf8 {
114 self.as_wtf8()
115 }
116}
117
118impl AsRef<Wtf8> for CachedPyStrRef {
119 #[inline]
120 fn as_ref(&self) -> &Wtf8 {
121 self.as_wtf8()
122 }
123}
124
impl CachedPyStrRef {
    /// Reinterpret this entry as a `'static` interned-string reference.
    ///
    /// # Safety
    ///
    /// `transmute_copy` copies the bits of `self` (a `repr(transparent)` wrapper around
    /// `PyRefExact<PyStr>`) into a `&'static PyStrInterned`. NOTE(review): this presumably
    /// relies on `PyRefExact<PyStr>` being a single pointer to the `Py<PyStr>`, and on
    /// `PyInterned<PyStr>` being `repr(transparent)` over `Py<PyStr>`. Confirm.
    /// No reference count is taken, so the caller must ensure the object outlives every
    /// use of the returned reference (for pool entries, the pool holds it).
    #[inline]
    const unsafe fn as_interned_str(&self) -> &'static PyStrInterned {
        unsafe { core::mem::transmute_copy(self) }
    }

    /// The entry's string contents (backs the `Borrow<Wtf8>` and `AsRef<Wtf8>` impls).
    #[inline]
    fn as_wtf8(&self) -> &Wtf8 {
        self.inner.as_wtf8()
    }
}
138
/// A `Py<T>` handed out through a `&'static` reference (see `PyInterned::leak`).
///
/// Comparison and hashing go by object identity (see the `PartialEq`/`Hash` impls).
// `repr(transparent)` is load-bearing: `leak`, `to_owned`, `CachedPyStrRef::as_interned_str`
// and `<Py<PyStr> as MaybeInternedString>::as_interned` reinterpret pointers to `Py<T>` /
// `PyRef<T>` as pointers to this type and back.
#[repr(transparent)]
pub struct PyInterned<T> {
    inner: Py<T>,
}
143
144impl PyInterned<PyStr> {
145 #[inline]
151 pub fn as_str(&self) -> &str {
152 self.inner
153 .to_str()
154 .unwrap_or_else(|| panic!("interned str is always valid UTF-8"))
155 }
156}
157
impl<T: PyPayload> PyInterned<T> {
    /// Turn an owned reference into a `&'static PyInterned<T>`.
    ///
    /// `cache` is consumed by `transmute` without being dropped, so the reference count it
    /// owned is never released through it.
    #[inline]
    pub fn leak(cache: PyRef<T>) -> &'static Self {
        // NOTE(review): assumes `PyRef<T>` is a single non-null pointer to `Py<T>`. Confirm.
        // `PyInterned<T>` is `repr(transparent)` over `Py<T>`, so the pointee layout matches.
        unsafe { core::mem::transmute(cache) }
    }

    /// Raw pointer to the wrapped `Py<T>`: the same address as `self`, because of
    /// `repr(transparent)`.
    #[inline]
    const fn as_ptr(&self) -> *const Py<T> {
        self as *const _ as *const _
    }

    /// Produce a new owned `PyRef<T>` to the interned object via `PyRef::clone`.
    #[inline]
    pub fn to_owned(&'static self) -> PyRef<T> {
        // `&self` points at the local `&'static Self`, i.e. at a pointer to the object.
        // Reading that slot as a `PyRef<T>` only borrows it: the place is not moved out of
        // and nothing is dropped. `.clone()` then yields an owned, counted reference.
        // NOTE(review): relies on `PyRef<T>` being layout-compatible with `&Py<T>`. Confirm.
        unsafe { (*(&self as *const _ as *const PyRef<T>)).clone() }
    }

    /// Like `to_owned`, but converted into a generic `PyObjectRef`.
    #[inline]
    pub fn to_object(&'static self) -> PyObjectRef {
        self.to_owned().into()
    }
}
179
180impl<T: PyPayload> Borrow<PyObject> for PyInterned<T> {
181 #[inline(always)]
182 fn borrow(&self) -> &PyObject {
183 self.inner.borrow()
184 }
185}
186
187impl<T: PyPayload> core::hash::Hash for PyInterned<T> {
190 #[inline(always)]
191 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
192 self.get_id().hash(state)
193 }
194}
195
196impl<T> AsRef<Py<T>> for PyInterned<T> {
197 #[inline(always)]
198 fn as_ref(&self) -> &Py<T> {
199 &self.inner
200 }
201}
202
203impl<T> Deref for PyInterned<T> {
204 type Target = Py<T>;
205 #[inline(always)]
206 fn deref(&self) -> &Self::Target {
207 &self.inner
208 }
209}
210
211impl<T: PyPayload> PartialEq for PyInterned<T> {
212 #[inline(always)]
213 fn eq(&self, other: &Self) -> bool {
214 core::ptr::eq(self, other)
215 }
216}
217
218impl<T: PyPayload> Eq for PyInterned<T> {}
219
220impl<T: core::fmt::Debug + PyPayload> core::fmt::Debug for PyInterned<T> {
221 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
222 core::fmt::Debug::fmt(&**self, f)?;
223 write!(f, "@{:p}", self.as_ptr())
224 }
225}
226
227impl<T: PyPayload> ToPyObject for &'static PyInterned<T> {
228 fn to_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef {
229 self.to_owned().into()
230 }
231}
232
233mod sealed {
234 use rustpython_common::wtf8::{Wtf8, Wtf8Buf};
235
236 use crate::{
237 builtins::PyStr,
238 object::{Py, PyExact, PyRefExact},
239 };
240
241 pub trait SealedInternable {}
242
243 impl SealedInternable for String {}
244 impl SealedInternable for &str {}
245 impl SealedInternable for Wtf8Buf {}
246 impl SealedInternable for &Wtf8 {}
247 impl SealedInternable for PyRefExact<PyStr> {}
248
249 pub trait SealedMaybeInterned {}
250
251 impl SealedMaybeInterned for str {}
252 impl SealedMaybeInterned for Wtf8 {}
253 impl SealedMaybeInterned for PyExact<PyStr> {}
254 impl SealedMaybeInterned for Py<PyStr> {}
255}
256
/// String-like values that `StringPool::intern` can turn into an interned `str`.
///
/// Sealed: only types implementing `sealed::SealedInternable` qualify.
pub trait InternableString: sealed::SealedInternable + ToPyObject + AsRef<Self::Interned> {
    /// Borrowed form used for the pool lookup that `intern` performs before converting.
    type Interned: MaybeInternedString + ?Sized;
    /// Convert into an exact `str` object; any newly created object gets type `str_type`.
    fn into_pyref_exact(self, str_type: PyTypeRef) -> PyRefExact<PyStr>;
}
262
263impl InternableString for String {
264 type Interned = str;
265 #[inline]
266 fn into_pyref_exact(self, str_type: PyTypeRef) -> PyRefExact<PyStr> {
267 let obj = PyRef::new_ref(PyStr::from(self), str_type, None);
268 unsafe { PyRefExact::new_unchecked(obj) }
269 }
270}
271
272impl InternableString for &str {
273 type Interned = str;
274 #[inline]
275 fn into_pyref_exact(self, str_type: PyTypeRef) -> PyRefExact<PyStr> {
276 self.to_owned().into_pyref_exact(str_type)
277 }
278}
279
280impl InternableString for Wtf8Buf {
281 type Interned = Wtf8;
282 fn into_pyref_exact(self, str_type: PyTypeRef) -> PyRefExact<PyStr> {
283 let obj = PyRef::new_ref(PyStr::from(self), str_type, None);
284 unsafe { PyRefExact::new_unchecked(obj) }
285 }
286}
287
288impl InternableString for &Wtf8 {
289 type Interned = Wtf8;
290 fn into_pyref_exact(self, str_type: PyTypeRef) -> PyRefExact<PyStr> {
291 self.to_owned().into_pyref_exact(str_type)
292 }
293}
294
295impl InternableString for PyRefExact<PyStr> {
296 type Interned = Py<PyStr>;
297 #[inline]
298 fn into_pyref_exact(self, _str_type: PyTypeRef) -> PyRefExact<PyStr> {
299 self
300 }
301}
302
/// String-like values that can be looked up in a `StringPool`.
///
/// Sealed via `sealed::SealedMaybeInterned`. `AsRef<Wtf8>` supplies the contents used for
/// the pool lookup.
pub trait MaybeInternedString:
    AsRef<Wtf8> + crate::dict_inner::DictKey + sealed::SealedMaybeInterned
{
    /// `Some` when `self` is already known to be interned, which lets
    /// `StringPool::interned` skip the pool lookup. `None` means "check the pool".
    fn as_interned(&self) -> Option<&'static PyStrInterned>;
}
308
309impl MaybeInternedString for str {
310 #[inline(always)]
311 fn as_interned(&self) -> Option<&'static PyStrInterned> {
312 None
313 }
314}
315
316impl MaybeInternedString for Wtf8 {
317 #[inline(always)]
318 fn as_interned(&self) -> Option<&'static PyStrInterned> {
319 None
320 }
321}
322
323impl MaybeInternedString for PyExact<PyStr> {
324 #[inline(always)]
325 fn as_interned(&self) -> Option<&'static PyStrInterned> {
326 None
327 }
328}
329
330impl MaybeInternedString for Py<PyStr> {
331 #[inline(always)]
332 fn as_interned(&self) -> Option<&'static PyStrInterned> {
333 if self.as_object().is_interned() {
334 Some(unsafe { core::mem::transmute::<&Self, &PyInterned<PyStr>>(self) })
335 } else {
336 None
337 }
338 }
339}
340
341impl PyObject {
342 #[inline]
343 pub fn as_interned_str(&self, vm: &crate::VirtualMachine) -> Option<&'static PyStrInterned> {
344 let s: Option<&Py<PyStr>> = self.downcast_ref();
345 if self.is_interned() {
346 s.unwrap().as_interned()
347 } else if let Some(s) = s {
348 vm.ctx.interned_str(s.as_wtf8())
349 } else {
350 None
351 }
352 }
353}