use super::{
PyStr, PyTupleRef, PyType, PyTypeRef, genericalias::PyGenericAlias,
interpolation::PyInterpolation,
};
use crate::{
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
atomic_func,
class::PyClassImpl,
common::lock::LazyLock,
function::{FuncArgs, PyComparisonValue},
protocol::{PyIterReturn, PySequenceMethods},
types::{
AsSequence, Comparable, Constructor, IterNext, Iterable, PyComparisonOp, Representable,
SelfIter,
},
};
use rustpython_common::wtf8::{Wtf8Buf, wtf8_concat};
// Payload type exposed to Python as `Template` in module `string.templatelib`
// (see the `pyclass` attribute below).
#[pyclass(module = "string.templatelib", name = "Template")]
#[derive(Debug, Clone)]
pub struct PyTemplate {
// Tuple holding the string segments of the template. The constructor in this
// file (`py_new`) always produces one more string than interpolations;
// `PyTemplate::new` stores whatever it is given without checking that.
pub strings: PyTupleRef,
// Tuple holding the interpolation objects, in the order they were supplied.
pub interpolations: PyTupleRef,
}
impl PyPayload for PyTemplate {
// Returns the type object stored as `template_type` in the context's type table.
#[inline]
fn class(ctx: &Context) -> &'static Py<PyType> {
ctx.types.template_type
}
}
impl PyTemplate {
/// Builds a template directly from already-constructed tuples.
///
/// The tuples are stored as given: no type checks are performed and the
/// relationship between the two lengths is not validated here.
pub fn new(strings: PyTupleRef, interpolations: PyTupleRef) -> Self {
Self {
strings,
interpolations,
}
}
}
impl Constructor for PyTemplate {
    type Args = FuncArgs;

    /// Builds a `PyTemplate` from positional arguments only.
    ///
    /// Each positional argument must either downcast to `PyStr` or have a class
    /// that is exactly `interpolation_type`; anything else yields a `TypeError`,
    /// as does any keyword argument.
    ///
    /// The resulting `strings` tuple is normalised while scanning:
    /// * consecutive strings are merged into a single string,
    /// * an empty string is inserted before an interpolation that does not
    ///   directly follow a string,
    /// * an empty string is appended if the last argument was not a string
    ///   (this also covers the no-argument case).
    fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
        if !args.kwargs.is_empty() {
            return Err(vm.new_type_error("Template.__new__ only accepts *args arguments"));
        }

        let mut string_parts: Vec<PyObjectRef> = Vec::new();
        let mut interp_parts: Vec<PyObjectRef> = Vec::new();
        // True while the most recently consumed argument was a string.
        let mut prev_is_str = false;

        for arg in args.args.iter() {
            match arg.clone().downcast::<PyStr>() {
                // A string directly after another string: fold it into the
                // previous entry instead of adding a new one.
                Ok(text) if prev_is_str => {
                    if let Some(slot) = string_parts.last_mut() {
                        // Every entry of `string_parts` was pushed as a `PyStr`,
                        // so this downcast is an invariant, not a runtime check.
                        let mut merged = slot.downcast_ref::<PyStr>().unwrap().as_wtf8().to_owned();
                        merged.push_wtf8(text.as_wtf8());
                        *slot = vm.ctx.new_str(merged).into();
                    }
                }
                Ok(text) => {
                    string_parts.push(text.into());
                    prev_is_str = true;
                }
                Err(_) if arg.class().is(vm.ctx.types.interpolation_type) => {
                    if !prev_is_str {
                        // Keep the alternation: every interpolation is
                        // preceded by a (possibly empty) string.
                        string_parts.push(vm.ctx.empty_str.to_owned().into());
                    }
                    interp_parts.push(arg.clone());
                    prev_is_str = false;
                }
                Err(_) => {
                    return Err(vm.new_type_error(format!(
                        "Template.__new__ *args need to be of type 'str' or 'Interpolation', got {}",
                        arg.class().name()
                    )));
                }
            }
        }

        if !prev_is_str {
            // Close the sequence with a trailing empty string.
            string_parts.push(vm.ctx.empty_str.to_owned().into());
        }

        Ok(Self {
            strings: vm.ctx.new_tuple(string_parts),
            interpolations: vm.ctx.new_tuple(interp_parts),
        })
    }
}
#[pyclass(with(Constructor, Comparable, Iterable, Representable, AsSequence))]
impl PyTemplate {
    // Getter: hands out another reference to the stored `strings` tuple.
    #[pygetset]
    fn strings(&self) -> PyTupleRef {
        self.strings.clone()
    }

    // Getter: hands out another reference to the stored `interpolations` tuple.
    #[pygetset]
    fn interpolations(&self) -> PyTupleRef {
        self.interpolations.clone()
    }

    // Getter: builds a fresh tuple with the `value` of every interpolation.
    // An entry that does not downcast to `PyInterpolation` is passed through
    // unchanged.
    #[pygetset]
    fn values(&self, vm: &VirtualMachine) -> PyTupleRef {
        let mut extracted: Vec<PyObjectRef> = Vec::with_capacity(self.interpolations.len());
        for entry in self.interpolations.iter() {
            let value = match entry.downcast_ref::<PyInterpolation>() {
                Some(interpolation) => interpolation.value.clone(),
                None => entry.clone(),
            };
            extracted.push(value);
        }
        vm.ctx.new_tuple(extracted)
    }

    // Concatenates `self` with another template and returns a new template.
    //
    // The last string of `self` and the first string of `other` are joined
    // into one string; all remaining strings and all interpolations keep
    // their order. Returns a `TypeError` when `other` is not a `PyTemplate`.
    fn concat(&self, other: &PyObject, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
        let Some(rhs) = other.downcast_ref::<PyTemplate>() else {
            return Err(vm.new_type_error(format!(
                "can only concatenate Template (not '{}') to Template",
                other.class().name()
            )));
        };

        // Index of the final string on the left-hand side (0 if it has none).
        let last_idx = self.strings.len().saturating_sub(1);

        let mut merged_strings: Vec<PyObjectRef> =
            Vec::with_capacity(self.strings.len() + rhs.strings.len());
        // Everything on the left except its final string.
        merged_strings.extend(self.strings.iter().take(last_idx).cloned());

        let mut merged_interps: Vec<PyObjectRef> =
            Vec::with_capacity(self.interpolations.len() + rhs.interpolations.len());
        merged_interps.extend(self.interpolations.iter().cloned());

        // The string at the seam: left tail followed by right head. Entries
        // that are missing or are not `PyStr` contribute nothing.
        let mut seam = Wtf8Buf::new();
        let lhs_tail = self
            .strings
            .get(last_idx)
            .and_then(|s| s.downcast_ref::<PyStr>());
        let rhs_head = rhs
            .strings
            .first()
            .and_then(|s| s.downcast_ref::<PyStr>());
        for piece in [lhs_tail, rhs_head].into_iter().flatten() {
            seam.push_wtf8(piece.as_wtf8());
        }
        merged_strings.push(vm.ctx.new_str(seam).into());

        // Everything on the right except its first string.
        merged_strings.extend(rhs.strings.iter().skip(1).cloned());
        merged_interps.extend(rhs.interpolations.iter().cloned());

        let combined = PyTemplate {
            strings: vm.ctx.new_tuple(merged_strings),
            interpolations: vm.ctx.new_tuple(merged_interps),
        };
        Ok(combined.into_ref(&vm.ctx))
    }

    // Addition entry point; forwards to `concat`.
    fn __add__(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
        self.concat(&other, vm)
    }

    // Subscripting the class produces a generic alias built from the class
    // and the subscript arguments.
    #[pyclassmethod]
    fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
        PyGenericAlias::from_args(cls, args, vm)
    }

    // Returns `(unpickle_func, (strings, interpolations))`.
    //
    // The callable is found by importing "string.templatelib", reading the
    // `templatelib` attribute from the import result (presumably the
    // top-level `string` package; confirm against `vm.import`) and then
    // reading `_template_unpickle` from that. Any lookup failure is propagated.
    #[pymethod]
    fn __reduce__(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
        let unpickle_func = vm
            .import("string.templatelib", 0)?
            .get_attr("templatelib", vm)?
            .get_attr("_template_unpickle", vm)?;
        let state = vm.ctx.new_tuple(vec![
            self.strings.clone().into(),
            self.interpolations.clone().into(),
        ]);
        Ok(vm.ctx.new_tuple(vec![unpickle_func, state.into()]))
    }
}
impl AsSequence for PyTemplate {
/// Returns the sequence method table for `PyTemplate`.
///
/// Only the `concat` slot is filled in; every other slot is taken from
/// `PySequenceMethods::NOT_IMPLEMENTED`. The table is built once, on first
/// use, and stored in a `LazyLock` static.
fn as_sequence() -> &'static PySequenceMethods {
static AS_SEQUENCE: LazyLock<PySequenceMethods> = LazyLock::new(|| PySequenceMethods {
// Sequence concatenation forwards to `PyTemplate::concat` and converts
// the resulting template reference via `into()`.
concat: atomic_func!(|seq, other, vm| {
let zelf = PyTemplate::sequence_downcast(seq);
zelf.concat(other, vm).map(|t| t.into())
}),
..PySequenceMethods::NOT_IMPLEMENTED
});
&AS_SEQUENCE
}
}
impl Comparable for PyTemplate {
    /// Compares two templates through `PyComparisonOp::eq_only`.
    ///
    /// Inside the equality closure, `other` must be a `PyTemplate` (otherwise
    /// `class_or_notimplemented!` takes over). Two templates are equal when
    /// their `strings` tuples compare equal and their `interpolations` tuples
    /// compare equal; the interpolations are only compared when the strings
    /// already matched. Errors raised by either comparison are propagated.
    fn cmp(
        zelf: &Py<Self>,
        other: &PyObject,
        op: PyComparisonOp,
        vm: &VirtualMachine,
    ) -> PyResult<PyComparisonValue> {
        op.eq_only(|| {
            let rhs = class_or_notimplemented!(Self, other);

            let same_strings = vm.bool_eq(zelf.strings.as_object(), rhs.strings.as_object())?;
            if !same_strings {
                // Short-circuit: do not evaluate the second comparison.
                return Ok(false.into());
            }

            let same_interpolations = vm.bool_eq(
                zelf.interpolations.as_object(),
                rhs.interpolations.as_object(),
            )?;
            Ok(same_interpolations.into())
        })
    }
}
impl Iterable for PyTemplate {
/// Creates a new `PyTemplateIter` over this template and returns it as a
/// Python object. The iterator takes ownership of the `zelf` reference.
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
Ok(PyTemplateIter::new(zelf).into_pyobject(vm))
}
}
impl Representable for PyTemplate {
/// Produces `Template(strings=<repr of strings>, interpolations=<repr of interpolations>)`.
///
/// The two inner reprs are obtained by calling `repr` on the stored tuples;
/// an error from either call is propagated.
#[inline]
fn repr_wtf8(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
let strings_repr = zelf.strings.as_object().repr(vm)?;
let interp_repr = zelf.interpolations.as_object().repr(vm)?;
Ok(wtf8_concat!(
"Template(strings=",
strings_repr.as_wtf8(),
", interpolations=",
interp_repr.as_wtf8(),
')',
))
}
}
// Iterator state for walking a `PyTemplate`, exposed to Python as
// `TemplateIter` in module `string.templatelib` (see the `pyclass` attribute).
#[pyclass(module = "string.templatelib", name = "TemplateIter")]
#[derive(Debug)]
pub struct PyTemplateIter {
// The template being iterated.
template: PyRef<PyTemplate>,
// Current position; used to index both `strings` and `interpolations`.
// It is advanced only after an interpolation has been yielded.
index: core::sync::atomic::AtomicUsize,
// `true` when the next item should be taken from `strings`,
// `false` when it should be taken from `interpolations`.
from_strings: core::sync::atomic::AtomicBool,
}
impl PyPayload for PyTemplateIter {
// Returns the type object stored as `template_iter_type` in the context's type table.
#[inline]
fn class(ctx: &Context) -> &'static Py<PyType> {
ctx.types.template_iter_type
}
}
impl PyTemplateIter {
/// Creates an iterator positioned at the start of `template`:
/// index 0, with the first item to be taken from `strings`.
fn new(template: PyRef<PyTemplate>) -> Self {
Self {
template,
index: core::sync::atomic::AtomicUsize::new(0),
from_strings: core::sync::atomic::AtomicBool::new(true),
}
}
}
// No inherent Python-visible methods are defined here; the `with(...)` list
// names the `IterNext` and `Iterable` trait impls as the class's behavior.
#[pyclass(with(IterNext, Iterable))]
impl PyTemplateIter {}
// Marker impl with no overridden items. Presumably this is what supplies the
// `Iterable` impl named in the `pyclass` attribute, with iteration returning
// the iterator itself; confirm against the definition of `SelfIter`.
impl SelfIter for PyTemplateIter {}
impl IterNext for PyTemplateIter {
    /// Yields the template's parts in alternating order: string, interpolation,
    /// string, and so on.
    ///
    /// Strings that are `PyStr` and empty are skipped rather than yielded.
    /// Iteration stops when the side whose turn it is has no entry at the
    /// current index.
    fn next(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<PyIterReturn> {
        use core::sync::atomic::Ordering::SeqCst;

        let template = &zelf.template;
        loop {
            let want_string = zelf.from_strings.load(SeqCst);
            let pos = zelf.index.load(SeqCst);

            if !want_string {
                // Interpolation turn: yield it, advance the shared index and
                // hand the turn back to the strings.
                let Some(interp) = template.interpolations.get(pos) else {
                    return Ok(PyIterReturn::StopIteration(None));
                };
                zelf.index.fetch_add(1, SeqCst);
                zelf.from_strings.store(true, SeqCst);
                return Ok(PyIterReturn::Return(interp.clone()));
            }

            // String turn.
            let Some(text) = template.strings.get(pos) else {
                return Ok(PyIterReturn::StopIteration(None));
            };
            // The turn passes to the interpolations whether or not this
            // string ends up being yielded.
            zelf.from_strings.store(false, SeqCst);

            let is_empty_str = text
                .downcast_ref::<PyStr>()
                .is_some_and(|s| s.as_wtf8().is_empty());
            if !is_empty_str {
                return Ok(PyIterReturn::Return(text.clone()));
            }
            // Empty string: loop again and try the interpolation at `pos`.
        }
    }
}
/// Calls `extend_class` for `PyTemplate` and `PyTemplateIter`, passing the
/// corresponding type objects (`template_type`, `template_iter_type`) from
/// the context's type table.
pub fn init(context: &'static Context) {
PyTemplate::extend_class(context, context.types.template_type);
PyTemplateIter::extend_class(context, context.types.template_iter_type);
}