Skip to main content

reifydb_routine/function/text/
replace.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData};
5use reifydb_type::{
6	util::bitvec::BitVec,
7	value::{constraint::bytes::MaxBytes, container::utf8::Utf8Container, r#type::Type},
8};
9
10use crate::function::{Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
11
12pub struct TextReplace {
13	info: FunctionInfo,
14}
15
16impl Default for TextReplace {
17	fn default() -> Self {
18		Self::new()
19	}
20}
21
22impl TextReplace {
23	pub fn new() -> Self {
24		Self {
25			info: FunctionInfo::new("text::replace"),
26		}
27	}
28}
29
30impl Function for TextReplace {
31	fn info(&self) -> &FunctionInfo {
32		&self.info
33	}
34
35	fn capabilities(&self) -> &[FunctionCapability] {
36		&[FunctionCapability::Scalar]
37	}
38
39	fn return_type(&self, _input_types: &[Type]) -> Type {
40		Type::Utf8
41	}
42
43	fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
44		if args.len() != 3 {
45			return Err(FunctionError::ArityMismatch {
46				function: ctx.fragment.clone(),
47				expected: 3,
48				actual: args.len(),
49			});
50		}
51
52		let str_col = &args[0];
53		let from_col = &args[1];
54		let to_col = &args[2];
55
56		let (str_data, str_bv) = str_col.data().unwrap_option();
57		let (from_data, from_bv) = from_col.data().unwrap_option();
58		let (to_data, to_bv) = to_col.data().unwrap_option();
59		let row_count = str_data.len();
60
61		match (str_data, from_data, to_data) {
62			(
63				ColumnData::Utf8 {
64					container: str_container,
65					..
66				},
67				ColumnData::Utf8 {
68					container: from_container,
69					..
70				},
71				ColumnData::Utf8 {
72					container: to_container,
73					..
74				},
75			) => {
76				let mut result_data = Vec::with_capacity(row_count);
77
78				for i in 0..row_count {
79					if str_container.is_defined(i)
80						&& from_container.is_defined(i) && to_container.is_defined(i)
81					{
82						let s = &str_container[i];
83						let from = &from_container[i];
84						let to = &to_container[i];
85						result_data.push(s.replace(from.as_str(), to.as_str()));
86					} else {
87						result_data.push(String::new());
88					}
89				}
90
91				let result_col_data = ColumnData::Utf8 {
92					container: Utf8Container::new(result_data),
93					max_bytes: MaxBytes::MAX,
94				};
95
96				// Combine all three bitvecs
97				let mut combined_bv: Option<BitVec> = None;
98				for bv in [str_bv, from_bv, to_bv].into_iter().flatten() {
99					combined_bv = Some(match combined_bv {
100						Some(existing) => existing.and(bv),
101						None => bv.clone(),
102					});
103				}
104
105				let final_data = match combined_bv {
106					Some(bv) => ColumnData::Option {
107						inner: Box::new(result_col_data),
108						bitvec: bv,
109					},
110					None => result_col_data,
111				};
112				Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), final_data)]))
113			}
114			(
115				ColumnData::Utf8 {
116					..
117				},
118				ColumnData::Utf8 {
119					..
120				},
121				other,
122			) => Err(FunctionError::InvalidArgumentType {
123				function: ctx.fragment.clone(),
124				argument_index: 2,
125				expected: vec![Type::Utf8],
126				actual: other.get_type(),
127			}),
128			(
129				ColumnData::Utf8 {
130					..
131				},
132				other,
133				_,
134			) => Err(FunctionError::InvalidArgumentType {
135				function: ctx.fragment.clone(),
136				argument_index: 1,
137				expected: vec![Type::Utf8],
138				actual: other.get_type(),
139			}),
140			(other, _, _) => Err(FunctionError::InvalidArgumentType {
141				function: ctx.fragment.clone(),
142				argument_index: 0,
143				expected: vec![Type::Utf8],
144				actual: other.get_type(),
145			}),
146		}
147	}
148}