1use super::{
2 PyStr, PyTupleRef, PyType, PyTypeRef, genericalias::PyGenericAlias,
3 interpolation::PyInterpolation,
4};
5use crate::{
6 AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
7 atomic_func,
8 class::PyClassImpl,
9 common::lock::LazyLock,
10 function::{FuncArgs, PyComparisonValue},
11 protocol::{PyIterReturn, PySequenceMethods},
12 types::{
13 AsSequence, Comparable, Constructor, IterNext, Iterable, PyComparisonOp, Representable,
14 SelfIter,
15 },
16};
17use rustpython_common::wtf8::{Wtf8Buf, wtf8_concat};
18
// Runtime object behind `string.templatelib.Template`.
//
// A template is stored as two tuples: the literal string segments and the
// interpolation objects that sit between them. Templates built by `Template.__new__`
// and by concatenation always hold exactly one more string than interpolations
// (empty strings fill any gaps), so `strings[i]` precedes `interpolations[i]`.
// `PyTemplate::new` does not check this; callers using it must uphold it.
#[pyclass(module = "string.templatelib", name = "Template")]
#[derive(Debug, Clone)]
pub struct PyTemplate {
    // Literal string segments, in source order.
    pub strings: PyTupleRef,
    // Interpolation objects, one between each adjacent pair of `strings`.
    pub interpolations: PyTupleRef,
}
28
impl PyPayload for PyTemplate {
    /// Returns the `Template` type object stored in `ctx.types`.
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.template_type
    }
}
35
36impl PyTemplate {
37 pub fn new(strings: PyTupleRef, interpolations: PyTupleRef) -> Self {
38 Self {
39 strings,
40 interpolations,
41 }
42 }
43}
44
45impl Constructor for PyTemplate {
46 type Args = FuncArgs;
47
48 fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
49 if !args.kwargs.is_empty() {
50 return Err(vm.new_type_error("Template.__new__ only accepts *args arguments"));
51 }
52
53 let mut strings: Vec<PyObjectRef> = Vec::new();
54 let mut interpolations: Vec<PyObjectRef> = Vec::new();
55 let mut last_was_str = false;
56
57 for item in args.args.iter() {
58 if let Ok(s) = item.clone().downcast::<PyStr>() {
59 if last_was_str {
60 if let Some(last) = strings.last_mut() {
62 let last_str = last.downcast_ref::<PyStr>().unwrap();
63 let mut buf = last_str.as_wtf8().to_owned();
64 buf.push_wtf8(s.as_wtf8());
65 *last = vm.ctx.new_str(buf).into();
66 }
67 } else {
68 strings.push(s.into());
69 }
70 last_was_str = true;
71 } else if item.class().is(vm.ctx.types.interpolation_type) {
72 if !last_was_str {
73 strings.push(vm.ctx.empty_str.to_owned().into());
75 }
76 interpolations.push(item.clone());
77 last_was_str = false;
78 } else {
79 return Err(vm.new_type_error(format!(
80 "Template.__new__ *args need to be of type 'str' or 'Interpolation', got {}",
81 item.class().name()
82 )));
83 }
84 }
85
86 if !last_was_str {
87 strings.push(vm.ctx.empty_str.to_owned().into());
89 }
90
91 Ok(PyTemplate {
92 strings: vm.ctx.new_tuple(strings),
93 interpolations: vm.ctx.new_tuple(interpolations),
94 })
95 }
96}
97
98#[pyclass(with(Constructor, Comparable, Iterable, Representable, AsSequence))]
99impl PyTemplate {
100 #[pygetset]
101 fn strings(&self) -> PyTupleRef {
102 self.strings.clone()
103 }
104
105 #[pygetset]
106 fn interpolations(&self) -> PyTupleRef {
107 self.interpolations.clone()
108 }
109
110 #[pygetset]
111 fn values(&self, vm: &VirtualMachine) -> PyTupleRef {
112 let values: Vec<PyObjectRef> = self
113 .interpolations
114 .iter()
115 .map(|interp| {
116 interp
117 .downcast_ref::<PyInterpolation>()
118 .map(|i| i.value.clone())
119 .unwrap_or_else(|| interp.clone())
120 })
121 .collect();
122 vm.ctx.new_tuple(values)
123 }
124
125 fn concat(&self, other: &PyObject, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
126 let other = other.downcast_ref::<PyTemplate>().ok_or_else(|| {
127 vm.new_type_error(format!(
128 "can only concatenate Template (not '{}') to Template",
129 other.class().name()
130 ))
131 })?;
132
133 let mut new_strings: Vec<PyObjectRef> = Vec::new();
135 let mut new_interps: Vec<PyObjectRef> = Vec::new();
136
137 let self_strings_len = self.strings.len();
139 for i in 0..self_strings_len.saturating_sub(1) {
140 new_strings.push(self.strings.get(i).unwrap().clone());
141 }
142
143 for interp in self.interpolations.iter() {
145 new_interps.push(interp.clone());
146 }
147
148 let mut buf = Wtf8Buf::new();
150 if let Some(s) = self
151 .strings
152 .get(self_strings_len.saturating_sub(1))
153 .and_then(|s| s.downcast_ref::<PyStr>())
154 {
155 buf.push_wtf8(s.as_wtf8());
156 }
157 if let Some(s) = other
158 .strings
159 .first()
160 .and_then(|s| s.downcast_ref::<PyStr>())
161 {
162 buf.push_wtf8(s.as_wtf8());
163 }
164 new_strings.push(vm.ctx.new_str(buf).into());
165
166 for i in 1..other.strings.len() {
168 new_strings.push(other.strings.get(i).unwrap().clone());
169 }
170
171 for interp in other.interpolations.iter() {
173 new_interps.push(interp.clone());
174 }
175
176 let template = PyTemplate {
177 strings: vm.ctx.new_tuple(new_strings),
178 interpolations: vm.ctx.new_tuple(new_interps),
179 };
180
181 Ok(template.into_ref(&vm.ctx))
182 }
183
184 fn __add__(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
185 self.concat(&other, vm)
186 }
187
188 #[pyclassmethod]
189 fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
190 PyGenericAlias::from_args(cls, args, vm)
191 }
192
193 #[pymethod]
194 fn __reduce__(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
195 let string_mod = vm.import("string.templatelib", 0)?;
199 let templatelib = string_mod.get_attr("templatelib", vm)?;
200 let unpickle_func = templatelib.get_attr("_template_unpickle", vm)?;
201
202 let args = vm.ctx.new_tuple(vec![
204 self.strings.clone().into(),
205 self.interpolations.clone().into(),
206 ]);
207 Ok(vm.ctx.new_tuple(vec![unpickle_func, args.into()]))
208 }
209}
210
211impl AsSequence for PyTemplate {
212 fn as_sequence() -> &'static PySequenceMethods {
213 static AS_SEQUENCE: LazyLock<PySequenceMethods> = LazyLock::new(|| PySequenceMethods {
214 concat: atomic_func!(|seq, other, vm| {
215 let zelf = PyTemplate::sequence_downcast(seq);
216 zelf.concat(other, vm).map(|t| t.into())
217 }),
218 ..PySequenceMethods::NOT_IMPLEMENTED
219 });
220 &AS_SEQUENCE
221 }
222}
223
224impl Comparable for PyTemplate {
225 fn cmp(
226 zelf: &Py<Self>,
227 other: &PyObject,
228 op: PyComparisonOp,
229 vm: &VirtualMachine,
230 ) -> PyResult<PyComparisonValue> {
231 op.eq_only(|| {
232 let other = class_or_notimplemented!(Self, other);
233
234 let eq = vm.bool_eq(zelf.strings.as_object(), other.strings.as_object())?
235 && vm.bool_eq(
236 zelf.interpolations.as_object(),
237 other.interpolations.as_object(),
238 )?;
239
240 Ok(eq.into())
241 })
242 }
243}
244
245impl Iterable for PyTemplate {
246 fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
247 Ok(PyTemplateIter::new(zelf).into_pyobject(vm))
248 }
249}
250
251impl Representable for PyTemplate {
252 #[inline]
253 fn repr_wtf8(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
254 let strings_repr = zelf.strings.as_object().repr(vm)?;
255 let interp_repr = zelf.interpolations.as_object().repr(vm)?;
256 Ok(wtf8_concat!(
257 "Template(strings=",
258 strings_repr.as_wtf8(),
259 ", interpolations=",
260 interp_repr.as_wtf8(),
261 ')',
262 ))
263 }
264}
265
// Iterator returned by `iter(Template)`. It walks the template's strings and
// interpolations alternately, skipping empty string segments.
#[pyclass(module = "string.templatelib", name = "TemplateIter")]
#[derive(Debug)]
pub struct PyTemplateIter {
    // The template being iterated.
    template: PyRef<PyTemplate>,
    // Index shared by `strings` and `interpolations`. It is incremented after each
    // interpolation is yielded.
    index: core::sync::atomic::AtomicUsize,
    // When true, the next item comes from `strings[index]`. Otherwise it comes from
    // `interpolations[index]`.
    from_strings: core::sync::atomic::AtomicBool,
}
274
impl PyPayload for PyTemplateIter {
    /// Returns the `TemplateIter` type object stored in `ctx.types`.
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.template_iter_type
    }
}
281
282impl PyTemplateIter {
283 fn new(template: PyRef<PyTemplate>) -> Self {
284 Self {
285 template,
286 index: core::sync::atomic::AtomicUsize::new(0),
287 from_strings: core::sync::atomic::AtomicBool::new(true),
288 }
289 }
290}
291
// Registers `TemplateIter` as a Python class whose behaviour comes entirely from
// the `IterNext` and `Iterable` impls listed in `with(...)`. It defines no
// additional methods of its own.
#[pyclass(with(IterNext, Iterable))]
impl PyTemplateIter {}
294
// Marker impl. It presumably supplies the `Iterable` behaviour named in the
// `#[pyclass(with(...))]` above, with `iter()` returning the iterator itself.
// Confirm against the `SelfIter` trait definition.
impl SelfIter for PyTemplateIter {}
296
297impl IterNext for PyTemplateIter {
298 fn next(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<PyIterReturn> {
299 use core::sync::atomic::Ordering;
300
301 loop {
302 let from_strings = zelf.from_strings.load(Ordering::SeqCst);
303 let index = zelf.index.load(Ordering::SeqCst);
304
305 if from_strings {
306 if index < zelf.template.strings.len() {
307 let item = zelf.template.strings.get(index).unwrap();
308 zelf.from_strings.store(false, Ordering::SeqCst);
309
310 if let Some(s) = item.downcast_ref::<PyStr>()
312 && s.as_wtf8().is_empty()
313 {
314 continue;
315 }
316 return Ok(PyIterReturn::Return(item.clone()));
317 } else {
318 return Ok(PyIterReturn::StopIteration(None));
319 }
320 } else if index < zelf.template.interpolations.len() {
321 let item = zelf.template.interpolations.get(index).unwrap();
322 zelf.index.fetch_add(1, Ordering::SeqCst);
323 zelf.from_strings.store(true, Ordering::SeqCst);
324 return Ok(PyIterReturn::Return(item.clone()));
325 } else {
326 return Ok(PyIterReturn::StopIteration(None));
327 }
328 }
329 }
330}
331
332pub fn init(context: &'static Context) {
333 PyTemplate::extend_class(context, context.types.template_type);
334 PyTemplateIter::extend_class(context, context.types.template_iter_type);
335}