Skip to main content

reifydb_routine/function/text/
replace.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns};
5use reifydb_type::{
6	util::bitvec::BitVec,
7	value::{constraint::bytes::MaxBytes, container::utf8::Utf8Container, r#type::Type},
8};
9
10use crate::routine::{Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError};
11
12pub struct TextReplace {
13	info: RoutineInfo,
14}
15
16impl Default for TextReplace {
17	fn default() -> Self {
18		Self::new()
19	}
20}
21
22impl TextReplace {
23	pub fn new() -> Self {
24		Self {
25			info: RoutineInfo::new("text::replace"),
26		}
27	}
28}
29
30impl<'a> Routine<FunctionContext<'a>> for TextReplace {
31	fn info(&self) -> &RoutineInfo {
32		&self.info
33	}
34
35	fn return_type(&self, _input_types: &[Type]) -> Type {
36		Type::Utf8
37	}
38
39	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
40		if args.len() != 3 {
41			return Err(RoutineError::FunctionArityMismatch {
42				function: ctx.fragment.clone(),
43				expected: 3,
44				actual: args.len(),
45			});
46		}
47
48		let str_col = &args[0];
49		let from_col = &args[1];
50		let to_col = &args[2];
51
52		let (str_data, str_bv) = str_col.unwrap_option();
53		let (from_data, from_bv) = from_col.unwrap_option();
54		let (to_data, to_bv) = to_col.unwrap_option();
55		let row_count = str_data.len();
56
57		match (str_data, from_data, to_data) {
58			(
59				ColumnBuffer::Utf8 {
60					container: str_container,
61					..
62				},
63				ColumnBuffer::Utf8 {
64					container: from_container,
65					..
66				},
67				ColumnBuffer::Utf8 {
68					container: to_container,
69					..
70				},
71			) => {
72				let mut result_data = Vec::with_capacity(row_count);
73
74				for i in 0..row_count {
75					if str_container.is_defined(i)
76						&& from_container.is_defined(i) && to_container.is_defined(i)
77					{
78						let s = str_container.get(i).unwrap();
79						let from = from_container.get(i).unwrap();
80						let to = to_container.get(i).unwrap();
81						result_data.push(s.replace(from, to));
82					} else {
83						result_data.push(String::new());
84					}
85				}
86
87				let result_col_data = ColumnBuffer::Utf8 {
88					container: Utf8Container::new(result_data),
89					max_bytes: MaxBytes::MAX,
90				};
91
92				// Combine all three bitvecs
93				let mut combined_bv: Option<BitVec> = None;
94				for bv in [str_bv, from_bv, to_bv].into_iter().flatten() {
95					combined_bv = Some(match combined_bv {
96						Some(existing) => existing.and(bv),
97						None => bv.clone(),
98					});
99				}
100
101				let final_data = match combined_bv {
102					Some(bv) => ColumnBuffer::Option {
103						inner: Box::new(result_col_data),
104						bitvec: bv,
105					},
106					None => result_col_data,
107				};
108				Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
109			}
110			(
111				ColumnBuffer::Utf8 {
112					..
113				},
114				ColumnBuffer::Utf8 {
115					..
116				},
117				other,
118			) => Err(RoutineError::FunctionInvalidArgumentType {
119				function: ctx.fragment.clone(),
120				argument_index: 2,
121				expected: vec![Type::Utf8],
122				actual: other.get_type(),
123			}),
124			(
125				ColumnBuffer::Utf8 {
126					..
127				},
128				other,
129				_,
130			) => Err(RoutineError::FunctionInvalidArgumentType {
131				function: ctx.fragment.clone(),
132				argument_index: 1,
133				expected: vec![Type::Utf8],
134				actual: other.get_type(),
135			}),
136			(other, _, _) => Err(RoutineError::FunctionInvalidArgumentType {
137				function: ctx.fragment.clone(),
138				argument_index: 0,
139				expected: vec![Type::Utf8],
140				actual: other.get_type(),
141			}),
142		}
143	}
144}
145
146impl Function for TextReplace {
147	fn kinds(&self) -> &[FunctionKind] {
148		&[FunctionKind::Scalar]
149	}
150}