Skip to main content

reifydb_routine/function/text/
substring.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns};
5use reifydb_type::{
6	util::bitvec::BitVec,
7	value::{container::utf8::Utf8Container, r#type::Type},
8};
9
10use crate::routine::{Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError};
11
12pub struct TextSubstring {
13	info: RoutineInfo,
14}
15
16impl Default for TextSubstring {
17	fn default() -> Self {
18		Self::new()
19	}
20}
21
22impl TextSubstring {
23	pub fn new() -> Self {
24		Self {
25			info: RoutineInfo::new("text::substring"),
26		}
27	}
28}
29
30impl<'a> Routine<FunctionContext<'a>> for TextSubstring {
31	fn info(&self) -> &RoutineInfo {
32		&self.info
33	}
34
35	fn return_type(&self, _input_types: &[Type]) -> Type {
36		Type::Utf8
37	}
38
39	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
40		if args.len() != 3 {
41			return Err(RoutineError::FunctionArityMismatch {
42				function: ctx.fragment.clone(),
43				expected: 3,
44				actual: args.len(),
45			});
46		}
47
48		let text_col = &args[0];
49		let start_col = &args[1];
50		let length_col = &args[2];
51
52		let (text_data, text_bv) = text_col.unwrap_option();
53		let (start_data, start_bv) = start_col.unwrap_option();
54		let (length_data, length_bv) = length_col.unwrap_option();
55		let row_count = text_data.len();
56
57		match (text_data, start_data, length_data) {
58			(
59				ColumnBuffer::Utf8 {
60					container: text_container,
61					max_bytes,
62				},
63				ColumnBuffer::Int4(start_container),
64				ColumnBuffer::Int4(length_container),
65			) => {
66				let mut result_data = Vec::with_capacity(text_container.len());
67
68				for i in 0..row_count {
69					if text_container.is_defined(i)
70						&& start_container.is_defined(i) && length_container.is_defined(i)
71					{
72						let original_str = text_container.get(i).unwrap();
73						let start_pos = start_container.get(i).copied().unwrap_or(0);
74						let length = length_container.get(i).copied().unwrap_or(0);
75
76						let chars: Vec<char> = original_str.chars().collect();
77						let chars_len = chars.len();
78
79						let start_idx = if start_pos < 0 {
80							chars_len.saturating_sub((-start_pos) as usize)
81						} else {
82							start_pos as usize
83						};
84						let length_usize = if length < 0 {
85							0
86						} else {
87							length as usize
88						};
89
90						let substring = if start_idx >= chars_len {
91							String::new()
92						} else {
93							let end_idx = (start_idx + length_usize).min(chars_len);
94							chars[start_idx..end_idx].iter().collect()
95						};
96
97						result_data.push(substring);
98					} else {
99						result_data.push(String::new());
100					}
101				}
102
103				let result_col_data = ColumnBuffer::Utf8 {
104					container: Utf8Container::new(result_data),
105					max_bytes: *max_bytes,
106				};
107
108				let mut combined_bv: Option<BitVec> = None;
109				for bv in [text_bv, start_bv, length_bv].into_iter().flatten() {
110					combined_bv = Some(match combined_bv {
111						Some(existing) => existing.and(bv),
112						None => bv.clone(),
113					});
114				}
115
116				let final_data = match combined_bv {
117					Some(bv) => ColumnBuffer::Option {
118						inner: Box::new(result_col_data),
119						bitvec: bv,
120					},
121					None => result_col_data,
122				};
123				Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
124			}
125
126			(
127				ColumnBuffer::Utf8 {
128					container: text_container,
129					max_bytes,
130				},
131				start_d,
132				length_d,
133			) => {
134				let mut result_data = Vec::with_capacity(text_container.len());
135
136				for i in 0..row_count {
137					if text_container.is_defined(i) {
138						let original_str = text_container.get(i).unwrap();
139
140						let start_pos = match start_d {
141							ColumnBuffer::Int1(container) => {
142								container.get(i).map(|&v| v as i32).unwrap_or(0)
143							}
144							ColumnBuffer::Int2(container) => {
145								container.get(i).map(|&v| v as i32).unwrap_or(0)
146							}
147							ColumnBuffer::Int4(container) => {
148								container.get(i).copied().unwrap_or(0)
149							}
150							ColumnBuffer::Int8(container) => {
151								container.get(i).map(|&v| v as i32).unwrap_or(0)
152							}
153							_ => 0,
154						};
155
156						let length = match length_d {
157							ColumnBuffer::Int1(container) => {
158								container.get(i).map(|&v| v as i32).unwrap_or(0)
159							}
160							ColumnBuffer::Int2(container) => {
161								container.get(i).map(|&v| v as i32).unwrap_or(0)
162							}
163							ColumnBuffer::Int4(container) => {
164								container.get(i).copied().unwrap_or(0)
165							}
166							ColumnBuffer::Int8(container) => {
167								container.get(i).map(|&v| v as i32).unwrap_or(0)
168							}
169							_ => 0,
170						};
171
172						let chars: Vec<char> = original_str.chars().collect();
173						let chars_len = chars.len();
174
175						let start_idx = if start_pos < 0 {
176							chars_len.saturating_sub((-start_pos) as usize)
177						} else {
178							start_pos as usize
179						};
180						let length_usize = if length < 0 {
181							0
182						} else {
183							length as usize
184						};
185
186						let substring = if start_idx >= chars_len {
187							String::new()
188						} else {
189							let end_idx = (start_idx + length_usize).min(chars_len);
190							chars[start_idx..end_idx].iter().collect()
191						};
192
193						result_data.push(substring);
194					} else {
195						result_data.push(String::new());
196					}
197				}
198
199				let result_col_data = ColumnBuffer::Utf8 {
200					container: Utf8Container::new(result_data),
201					max_bytes: *max_bytes,
202				};
203
204				let mut combined_bv: Option<BitVec> = None;
205				for bv in [text_bv, start_bv, length_bv].into_iter().flatten() {
206					combined_bv = Some(match combined_bv {
207						Some(existing) => existing.and(bv),
208						None => bv.clone(),
209					});
210				}
211
212				let final_data = match combined_bv {
213					Some(bv) => ColumnBuffer::Option {
214						inner: Box::new(result_col_data),
215						bitvec: bv,
216					},
217					None => result_col_data,
218				};
219				Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
220			}
221			(other, _, _) => Err(RoutineError::FunctionInvalidArgumentType {
222				function: ctx.fragment.clone(),
223				argument_index: 0,
224				expected: vec![Type::Utf8],
225				actual: other.get_type(),
226			}),
227		}
228	}
229}
230
231impl Function for TextSubstring {
232	fn kinds(&self) -> &[FunctionKind] {
233		&[FunctionKind::Scalar]
234	}
235}