import math
from contextlib import contextmanager

from tests.fixtures.spark_imports import get_spark_imports
# SparkSession and the column-function namespace F are taken from the shared
# fixture helper instead of a direct pyspark import (presumably so the backend
# can be swapped -- confirm in tests.fixtures.spark_imports).
_imports = get_spark_imports()
SparkSession, F = _imports.SparkSession, _imports.F
class TestIssue329LogFloatConstant:
    """Regression tests for issue 329: ``F.log`` with a numeric constant base.

    Each test opens its own session through ``_spark_session`` and checks
    ``F.log(base, col)`` (and the single-argument natural-log form) against
    values computed with Python's ``math`` module.
    """

    def _get_unique_app_name(self, test_name: str) -> str:
        """Return an app name built from *test_name*, the pid and the thread id.

        Including the process and thread ids keeps names distinct when the
        same test runs concurrently in different processes or threads.
        """
        import os
        import threading

        thread_id = threading.current_thread().ident
        process_id = os.getpid()
        return f"{test_name}_{process_id}_{thread_id}"

    @contextmanager
    def _spark_session(self, test_name: str):
        """Yield a session whose app name derives from *test_name*; stop it on exit.

        The test name is passed explicitly. ``inspect.stack()[1]`` would be the
        frame that *called* the test (the test runner), so every test would
        end up with the same app name.
        """
        spark = SparkSession.builder.appName(
            self._get_unique_app_name(test_name)
        ).getOrCreate()
        try:
            yield spark
        finally:
            # Runs even when an assertion inside the ``with`` body fails.
            spark.stop()

    @staticmethod
    def _assert_close(actual, expected: float, tol: float = 1e-4) -> None:
        """Assert that *actual* is non-null and within *tol* of *expected*."""
        assert actual is not None, f"expected {expected!r}, got None"
        assert math.isclose(actual, expected, abs_tol=tol), (
            f"expected {expected!r} (abs_tol={tol}), got {actual!r}"
        )

    @staticmethod
    def _rows_by_value(rows):
        """Index collected rows by their ``Value`` column (None included)."""
        return {row["Value"]: row for row in rows}

    def test_log_with_float_base(self):
        """``F.log(10.0, col)`` computes base-10 logarithms."""
        with self._spark_session("test_log_with_float_base") as spark:
            df = spark.createDataFrame([{"Value": 100.0}, {"Value": 1000.0}])
            rows = df.select(
                "Value",
                F.log(10.0, F.col("Value")).alias("Log10"),
            ).collect()

            assert len(rows) == 2
            by_value = self._rows_by_value(rows)
            self._assert_close(by_value[100.0]["Log10"], 2.0)
            self._assert_close(by_value[1000.0]["Log10"], 3.0)

    def test_log_with_int_base(self):
        """An integer constant base (``F.log(2, col)``) is accepted as well."""
        with self._spark_session("test_log_with_int_base") as spark:
            df = spark.createDataFrame([{"Value": 8.0}])
            # Deliberately an int literal: the float form is covered elsewhere.
            rows = df.select(F.log(2, F.col("Value")).alias("Log2")).collect()

            assert len(rows) == 1
            self._assert_close(rows[0]["Log2"], 3.0)

    def test_log_natural_log(self):
        """Single-argument ``F.log(col)`` is the natural logarithm."""
        with self._spark_session("test_log_natural_log") as spark:
            df = spark.createDataFrame([{"Value": math.e}])
            rows = df.select(F.log(F.col("Value")).alias("Ln")).collect()

            assert len(rows) == 1
            self._assert_close(rows[0]["Ln"], 1.0)

    def test_log_with_different_bases(self):
        """Several constant bases can be used side by side in one select."""
        with self._spark_session("test_log_with_different_bases") as spark:
            df = spark.createDataFrame([{"Value": 100.0}])
            rows = df.select(
                F.log(10.0, F.col("Value")).alias("Log10"),
                F.log(2.0, F.col("Value")).alias("Log2"),
                F.log(3.0, F.col("Value")).alias("Log3"),
            ).collect()

            assert len(rows) == 1
            row = rows[0]
            # Exact expected values let the tolerance stay tight instead of
            # relying on hand-rounded constants.
            self._assert_close(row["Log10"], math.log(100.0, 10))
            self._assert_close(row["Log2"], math.log(100.0, 2))
            self._assert_close(row["Log3"], math.log(100.0, 3))

    def test_log_with_null_values(self):
        """A null input yields a null result rather than an error."""
        with self._spark_session("test_log_with_null_values") as spark:
            df = spark.createDataFrame([{"Value": 100.0}, {"Value": None}])
            rows = df.select(
                "Value",
                F.log(10.0, F.col("Value")).alias("Log10"),
            ).collect()

            assert len(rows) == 2
            by_value = self._rows_by_value(rows)
            self._assert_close(by_value[100.0]["Log10"], 2.0)
            assert by_value[None]["Log10"] is None

    def test_log_in_with_column(self):
        """``F.log`` with a constant base also works inside ``withColumn``."""
        with self._spark_session("test_log_in_with_column") as spark:
            df = spark.createDataFrame([{"Value": 100.0}])
            df = df.withColumn("Log10", F.log(10.0, F.col("Value")))
            rows = df.collect()

            assert len(rows) == 1
            self._assert_close(rows[0]["Log10"], 2.0)

    def test_log_edge_cases(self):
        """Inputs at, below and above 1 give zero, negative and positive logs."""
        with self._spark_session("test_log_edge_cases") as spark:
            df = spark.createDataFrame(
                [{"Value": 1.0}, {"Value": 0.5}, {"Value": 2.0}]
            )
            rows = df.select(
                "Value",
                F.log(10.0, F.col("Value")).alias("Log10"),
                F.log(F.col("Value")).alias("Ln"),
            ).collect()

            assert len(rows) == 3
            by_value = self._rows_by_value(rows)
            for value in (1.0, 0.5, 2.0):
                row = by_value[value]
                self._assert_close(row["Log10"], math.log10(value))
                self._assert_close(row["Ln"], math.log(value))