import hashlib
import pathlib
import sys
from io import StringIO
# Skeleton of the generated Rust source file. It is rendered with
# ``str.format``: ``{ntables}`` is replaced by the number of tables and
# ``{content}`` by the comma-separated table bodies. The backslash right
# after the opening quotes keeps the output from starting with a blank line.
template = """\
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This table should be identical with
// https://github.com/apache/arrow/blob/main/cpp/src/parquet/chunker_internal_generated.h
// Ensure that both tables remain in sync after any changes.
#[rustfmt::skip]
pub(crate) const NUM_GEARHASH_TABLES: usize = {ntables};
#[rustfmt::skip]
pub(crate) const GEARHASH_TABLE: [[u64; 256]; NUM_GEARHASH_TABLES] = [
{content}];
"""
def generate_hash(n: int, seed: int):
    """Return a 16-character hex string derived from ``seed`` and ``n``.

    The MD5 input is 128 bytes long: the byte ``seed`` repeated 64 times,
    followed by the byte ``n`` repeated 64 times. Only the first 16 hex
    characters of the digest (its leading 8 bytes) are returned.

    Both arguments must fit in a single byte (0-255); ``bytes`` raises
    ``ValueError`` for anything outside that range.
    """
    seed_block = bytes([seed]) * 64
    index_block = bytes([n]) * 64
    digest = hashlib.md5(seed_block + index_block).hexdigest()
    return digest[:16]
def generate_hashtable(seed: int, length=256):
    """Render one hash table as the text of a bracketed array literal.

    The table has ``length`` entries; entry ``n`` is
    ``generate_hash(n, seed=seed)`` with a ``0x`` prefix. The returned text
    starts with a ``// seed = <seed>`` comment line, then an opening
    bracket, then rows of up to four comma-separated values, then the
    closing bracket. No newline follows the closing bracket, so the caller
    decides how tables are separated.
    """
    # Compute every entry first so a hashing failure surfaces before any
    # text is assembled.
    entries = [f"0x{generate_hash(index, seed=seed)}" for index in range(length)]
    rows = [
        f" {', '.join(entries[start : start + 4])},\n"
        for start in range(0, length, 4)
    ]
    return "".join([f" // seed = {seed}\n", " [\n", *rows, " ]"])
def generate_source(ntables=8, relative_path="cdc_generated.rs"):
    """Generate the hash tables and write the rendered Rust source file.

    Args:
        ntables: Number of tables to emit; table ``i`` is generated with
            seed ``i``. Must be between 0 and 256 inclusive, because each
            seed is encoded as a single byte when hashing.
        relative_path: Output file path, resolved relative to the directory
            that contains this script.

    Raises:
        ValueError: If ``ntables`` is negative or greater than 256.
    """
    # Validate before doing any work. A negative count would otherwise be
    # written out as an invalid ``usize`` constant with an empty table, and
    # a count above 256 would fail deep inside ``bytes()`` with a message
    # that never mentions ``ntables``.
    if not 0 <= ntables <= 256:
        raise ValueError(
            f"ntables must be between 0 and 256 inclusive, got {ntables}"
        )

    path = pathlib.Path(__file__).parent / relative_path
    tables = [generate_hashtable(seed) for seed in range(ntables)]
    content = ",\n".join(tables)
    text = template.format(ntables=ntables, content=content)
    # Pin the encoding so the output does not depend on the locale default.
    path.write_text(text, encoding="utf-8")
if __name__ == "__main__":
    # The optional first command-line argument overrides the default of
    # eight tables; any further arguments are ignored.
    cli_args = sys.argv[1:]
    ntables = int(cli_args[0]) if cli_args else 8
    generate_source(ntables)